Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ad7c8b0587 | |||
| 1625997d8f |
228
agent_forward.go
228
agent_forward.go
@ -4,7 +4,9 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
@ -15,24 +17,53 @@ var requestSSHAgentForwarding = func(session *ssh.Session) error {
|
|||||||
return sshagent.RequestAgentForwarding(session)
|
return sshagent.RequestAgentForwarding(session)
|
||||||
}
|
}
|
||||||
|
|
||||||
var routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error {
|
const sshAgentChannelType = "auth-agent@openssh.com"
|
||||||
return sshagent.ForwardToAgent(client, keyring)
|
|
||||||
|
var routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
||||||
|
return startSSHAgentForwardProxy(client, timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
var newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
var probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||||
conn, err := dialSSHAgent(timeout)
|
conn, err := dialSSHAgent(timeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, wrapSSHAgentForwardingUnavailable(err)
|
return wrapSSHAgentForwardingUnavailable(err)
|
||||||
}
|
}
|
||||||
if conn == nil {
|
if conn == nil {
|
||||||
return nil, nil, wrapSSHAgentForwardingUnavailable(errors.New("empty agent connection"))
|
return wrapSSHAgentForwardingUnavailable(errors.New("empty agent connection"))
|
||||||
}
|
}
|
||||||
return sshagent.NewClient(conn), conn, nil
|
return conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
var errSSHAgentForwardingDenied = errors.New("ssh-agent forwarding request denied")
|
var errSSHAgentForwardingDenied = errors.New("ssh-agent forwarding request denied")
|
||||||
var errSSHAgentForwardingUnavailable = errors.New("ssh-agent forwarding unavailable")
|
var errSSHAgentForwardingUnavailable = errors.New("ssh-agent forwarding unavailable")
|
||||||
|
|
||||||
|
type sshAgentForwardProxy struct {
|
||||||
|
stopOnce sync.Once
|
||||||
|
stopCh chan struct{}
|
||||||
|
|
||||||
|
activeMu sync.Mutex
|
||||||
|
active map[*sshAgentForwardBridge]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *sshAgentForwardProxy) Close() error {
|
||||||
|
if p == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
p.stopOnce.Do(func() {
|
||||||
|
close(p.stopCh)
|
||||||
|
})
|
||||||
|
p.closeActive()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type sshAgentForwardBridge struct {
|
||||||
|
proxy *sshAgentForwardProxy
|
||||||
|
channel ssh.Channel
|
||||||
|
conn net.Conn
|
||||||
|
|
||||||
|
closeOnce sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
func (s *StarSSH) RequestAgentForwarding(session *ssh.Session) error {
|
func (s *StarSSH) RequestAgentForwarding(session *ssh.Session) error {
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return errors.New("ssh client is nil")
|
return errors.New("ssh client is nil")
|
||||||
@ -80,20 +111,21 @@ func (s *StarSSH) ensureAgentForwarding() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
keyring, closer, err := newSSHAgentForwarder(s.LoginInfo.Timeout)
|
timeout := effectiveDialTimeout(s.LoginInfo)
|
||||||
if err != nil {
|
if err := probeSSHAgentForwarding(timeout); err != nil {
|
||||||
return wrapSSHAgentForwardingUnavailable(err)
|
return wrapSSHAgentForwardingUnavailable(err)
|
||||||
}
|
}
|
||||||
if s.closing.Load() {
|
if s.closing.Load() {
|
||||||
_ = closer.Close()
|
|
||||||
return errSSHClientClosing
|
return errSSHClientClosing
|
||||||
}
|
}
|
||||||
if err := routeSSHAgentForwarding(client, keyring); err != nil {
|
closer, err := routeSSHAgentForwarding(client, timeout)
|
||||||
_ = closer.Close()
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if !s.canAttachAgentForwarder(client) {
|
if !s.canAttachAgentForwarder(client) {
|
||||||
_ = closer.Close()
|
if closer != nil {
|
||||||
|
_ = closer.Close()
|
||||||
|
}
|
||||||
return errSSHClientClosing
|
return errSSHClientClosing
|
||||||
}
|
}
|
||||||
s.agentForwarder = closer
|
s.agentForwarder = closer
|
||||||
@ -149,3 +181,175 @@ func wrapSSHAgentForwardingUnavailable(err error) error {
|
|||||||
}
|
}
|
||||||
return fmt.Errorf("%w: %v", errSSHAgentForwardingUnavailable, err)
|
return fmt.Errorf("%w: %v", errSSHAgentForwardingUnavailable, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func startSSHAgentForwardProxy(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
||||||
|
if client == nil {
|
||||||
|
return nil, errors.New("ssh client is nil")
|
||||||
|
}
|
||||||
|
channels := client.HandleChannelOpen(sshAgentChannelType)
|
||||||
|
if channels == nil {
|
||||||
|
return nil, errors.New("agent: already have handler for " + sshAgentChannelType)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := &sshAgentForwardProxy{
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
|
active: make(map[*sshAgentForwardBridge]struct{}),
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-proxy.stopCh:
|
||||||
|
return
|
||||||
|
case ch, ok := <-channels:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go handleSSHAgentForwardChannel(proxy, ch, timeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return proxy, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleSSHAgentForwardChannel(proxy *sshAgentForwardProxy, ch ssh.NewChannel, timeout time.Duration) {
|
||||||
|
if ch == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn, err := dialSSHAgent(timeout)
|
||||||
|
if err != nil {
|
||||||
|
_ = ch.Reject(ssh.ConnectionFailed, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if conn == nil {
|
||||||
|
_ = ch.Reject(ssh.ConnectionFailed, "ssh-agent connection unavailable")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
channel, reqs, err := ch.Accept()
|
||||||
|
if err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go ssh.DiscardRequests(reqs)
|
||||||
|
|
||||||
|
bridge := &sshAgentForwardBridge{
|
||||||
|
proxy: proxy,
|
||||||
|
channel: channel,
|
||||||
|
conn: conn,
|
||||||
|
}
|
||||||
|
if !proxy.registerBridge(bridge) {
|
||||||
|
bridge.close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go bridge.run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func proxySSHAgentChannel(channel ssh.Channel, conn net.Conn) {
|
||||||
|
bridge := &sshAgentForwardBridge{
|
||||||
|
channel: channel,
|
||||||
|
conn: conn,
|
||||||
|
}
|
||||||
|
bridge.run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *sshAgentForwardBridge) run() {
|
||||||
|
if b == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer b.unregister()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_, _ = io.Copy(b.channel, b.conn)
|
||||||
|
b.close()
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_, _ = io.Copy(b.conn, b.channel)
|
||||||
|
b.close()
|
||||||
|
}()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *sshAgentForwardBridge) close() {
|
||||||
|
if b == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.closeOnce.Do(func() {
|
||||||
|
closeWriter(b.channel)
|
||||||
|
closeWriter(b.conn)
|
||||||
|
if b.channel != nil {
|
||||||
|
_ = b.channel.Close()
|
||||||
|
}
|
||||||
|
if b.conn != nil {
|
||||||
|
_ = b.conn.Close()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *sshAgentForwardBridge) unregister() {
|
||||||
|
if b == nil || b.proxy == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.proxy.unregisterBridge(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *sshAgentForwardProxy) registerBridge(bridge *sshAgentForwardBridge) bool {
|
||||||
|
if p == nil || bridge == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
p.activeMu.Lock()
|
||||||
|
defer p.activeMu.Unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-p.stopCh:
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
if p.active == nil {
|
||||||
|
p.active = make(map[*sshAgentForwardBridge]struct{})
|
||||||
|
}
|
||||||
|
p.active[bridge] = struct{}{}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *sshAgentForwardProxy) unregisterBridge(bridge *sshAgentForwardBridge) {
|
||||||
|
if p == nil || bridge == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.activeMu.Lock()
|
||||||
|
defer p.activeMu.Unlock()
|
||||||
|
|
||||||
|
delete(p.active, bridge)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *sshAgentForwardProxy) closeActive() {
|
||||||
|
if p == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.activeMu.Lock()
|
||||||
|
active := make([]*sshAgentForwardBridge, 0, len(p.active))
|
||||||
|
for bridge := range p.active {
|
||||||
|
active = append(active, bridge)
|
||||||
|
}
|
||||||
|
p.active = make(map[*sshAgentForwardBridge]struct{})
|
||||||
|
p.activeMu.Unlock()
|
||||||
|
|
||||||
|
for _, bridge := range active {
|
||||||
|
bridge.close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func closeWriter(value any) {
|
||||||
|
type closeWriter interface {
|
||||||
|
CloseWrite() error
|
||||||
|
}
|
||||||
|
if cw, ok := value.(closeWriter); ok {
|
||||||
|
_ = cw.CloseWrite()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -1,14 +1,16 @@
|
|||||||
package starssh
|
package starssh
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
sshagent "golang.org/x/crypto/ssh/agent"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type testCloser struct {
|
type testCloser struct {
|
||||||
@ -20,15 +22,90 @@ func (c *testCloser) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type trackedConn struct {
|
||||||
|
net.Conn
|
||||||
|
closed atomic.Int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *trackedConn) Close() error {
|
||||||
|
c.closed.Add(1)
|
||||||
|
if c.Conn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.Conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
type testSSHChannel struct {
|
||||||
|
readFunc func([]byte) (int, error)
|
||||||
|
|
||||||
|
stderr bytes.Buffer
|
||||||
|
closed atomic.Int32
|
||||||
|
closeOnce sync.Once
|
||||||
|
closeCh chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestSSHChannel(readFunc func([]byte) (int, error)) *testSSHChannel {
|
||||||
|
return &testSSHChannel{
|
||||||
|
readFunc: readFunc,
|
||||||
|
closeCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBlockingTestSSHChannel() *testSSHChannel {
|
||||||
|
ch := newTestSSHChannel(nil)
|
||||||
|
ch.readFunc = func(p []byte) (int, error) {
|
||||||
|
<-ch.closeCh
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
return ch
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testSSHChannel) Read(p []byte) (int, error) {
|
||||||
|
if c == nil {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
if c.readFunc != nil {
|
||||||
|
return c.readFunc(p)
|
||||||
|
}
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testSSHChannel) Write(p []byte) (int, error) {
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testSSHChannel) Close() error {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
c.closeOnce.Do(func() {
|
||||||
|
c.closed.Add(1)
|
||||||
|
close(c.closeCh)
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testSSHChannel) CloseWrite() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testSSHChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testSSHChannel) Stderr() io.ReadWriter {
|
||||||
|
return &c.stderr
|
||||||
|
}
|
||||||
|
|
||||||
func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
|
func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
|
||||||
oldNewSSHSession := newSSHSession
|
oldNewSSHSession := newSSHSession
|
||||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||||
oldCloseSSHClient := closeSSHClient
|
oldCloseSSHClient := closeSSHClient
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
newSSHSession = oldNewSSHSession
|
newSSHSession = oldNewSSHSession
|
||||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||||
closeSSHClient = oldCloseSSHClient
|
closeSSHClient = oldCloseSSHClient
|
||||||
@ -50,26 +127,26 @@ func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
|
|||||||
return &ssh.Session{}, nil
|
return &ssh.Session{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var agentInitCalls atomic.Int32
|
var probeCalls atomic.Int32
|
||||||
closer := &testCloser{}
|
closer := &testCloser{}
|
||||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||||
agentInitCalls.Add(1)
|
probeCalls.Add(1)
|
||||||
if timeout != time.Second {
|
if timeout != time.Second {
|
||||||
t.Fatalf("unexpected forwarding timeout: %v", timeout)
|
t.Fatalf("unexpected forwarding timeout: %v", timeout)
|
||||||
}
|
}
|
||||||
return sshagent.NewKeyring(), closer, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var routeCalls atomic.Int32
|
var routeCalls atomic.Int32
|
||||||
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error {
|
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
||||||
routeCalls.Add(1)
|
routeCalls.Add(1)
|
||||||
if client != baseClient {
|
if client != baseClient {
|
||||||
t.Fatalf("unexpected routed client %p", client)
|
t.Fatalf("unexpected routed client %p", client)
|
||||||
}
|
}
|
||||||
if keyring == nil {
|
if timeout != time.Second {
|
||||||
t.Fatal("expected non-nil forwarded agent keyring")
|
t.Fatalf("unexpected routed timeout: %v", timeout)
|
||||||
}
|
}
|
||||||
return nil
|
return closer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var requestCalls atomic.Int32
|
var requestCalls atomic.Int32
|
||||||
@ -88,8 +165,8 @@ func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
|
|||||||
t.Fatalf("second exec session: %v", err)
|
t.Fatalf("second exec session: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if agentInitCalls.Load() != 1 {
|
if probeCalls.Load() != 1 {
|
||||||
t.Fatalf("expected one agent forwarder init, got %d", agentInitCalls.Load())
|
t.Fatalf("expected one agent probe, got %d", probeCalls.Load())
|
||||||
}
|
}
|
||||||
if routeCalls.Load() != 1 {
|
if routeCalls.Load() != 1 {
|
||||||
t.Fatalf("expected one agent route registration, got %d", routeCalls.Load())
|
t.Fatalf("expected one agent route registration, got %d", routeCalls.Load())
|
||||||
@ -110,13 +187,13 @@ func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
|
|||||||
func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) {
|
func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) {
|
||||||
oldNewSSHSession := newSSHSession
|
oldNewSSHSession := newSSHSession
|
||||||
oldRequestSessionPTY := requestSessionPTY
|
oldRequestSessionPTY := requestSessionPTY
|
||||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
newSSHSession = oldNewSSHSession
|
newSSHSession = oldNewSSHSession
|
||||||
requestSessionPTY = oldRequestSessionPTY
|
requestSessionPTY = oldRequestSessionPTY
|
||||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||||
})
|
})
|
||||||
@ -138,10 +215,12 @@ func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||||
return sshagent.NewKeyring(), &testCloser{}, nil
|
return nil
|
||||||
|
}
|
||||||
|
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
||||||
|
return &testCloser{}, nil
|
||||||
}
|
}
|
||||||
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil }
|
|
||||||
|
|
||||||
var requestCalls atomic.Int32
|
var requestCalls atomic.Int32
|
||||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||||
@ -162,11 +241,11 @@ func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) {
|
|||||||
|
|
||||||
func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) {
|
func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) {
|
||||||
oldNewSSHSession := newSSHSession
|
oldNewSSHSession := newSSHSession
|
||||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
newSSHSession = oldNewSSHSession
|
newSSHSession = oldNewSSHSession
|
||||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -176,9 +255,9 @@ func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) {
|
|||||||
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||||
return &ssh.Session{}, nil
|
return &ssh.Session{}, nil
|
||||||
}
|
}
|
||||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||||
t.Fatal("agent forwarder should not initialize when disabled")
|
t.Fatal("agent forwarding probe should not run when disabled")
|
||||||
return nil, nil, nil
|
return nil
|
||||||
}
|
}
|
||||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||||
t.Fatal("agent forwarding should not be requested when disabled")
|
t.Fatal("agent forwarding should not be requested when disabled")
|
||||||
@ -191,18 +270,18 @@ func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRequestAgentForwardingReturnsUnavailableError(t *testing.T) {
|
func TestRequestAgentForwardingReturnsUnavailableError(t *testing.T) {
|
||||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||||
})
|
})
|
||||||
|
|
||||||
star := &StarSSH{}
|
star := &StarSSH{}
|
||||||
star.setTransport(&ssh.Client{}, nil)
|
star.setTransport(&ssh.Client{}, nil)
|
||||||
|
|
||||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||||
return nil, nil, errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
|
return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
|
||||||
}
|
}
|
||||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||||
t.Fatal("session request should not run when agent forwarder init fails")
|
t.Fatal("session request should not run when agent forwarder init fails")
|
||||||
@ -216,16 +295,16 @@ func TestRequestAgentForwardingReturnsUnavailableError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRequestAgentForwardingWrapsSetupErrorAsUnavailable(t *testing.T) {
|
func TestRequestAgentForwardingWrapsSetupErrorAsUnavailable(t *testing.T) {
|
||||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
})
|
})
|
||||||
|
|
||||||
star := &StarSSH{}
|
star := &StarSSH{}
|
||||||
star.setTransport(&ssh.Client{}, nil)
|
star.setTransport(&ssh.Client{}, nil)
|
||||||
|
|
||||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||||
return nil, nil, errors.New("dial unix /tmp/ssh-broken.sock: connect: permission denied")
|
return errors.New("dial unix /tmp/ssh-broken.sock: connect: permission denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
err := star.RequestAgentForwarding(&ssh.Session{})
|
err := star.RequestAgentForwarding(&ssh.Session{})
|
||||||
@ -235,11 +314,11 @@ func TestRequestAgentForwardingWrapsSetupErrorAsUnavailable(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) {
|
func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) {
|
||||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||||
})
|
})
|
||||||
@ -247,10 +326,12 @@ func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) {
|
|||||||
star := &StarSSH{}
|
star := &StarSSH{}
|
||||||
star.setTransport(&ssh.Client{}, nil)
|
star.setTransport(&ssh.Client{}, nil)
|
||||||
|
|
||||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||||
return sshagent.NewKeyring(), &testCloser{}, nil
|
return nil
|
||||||
|
}
|
||||||
|
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
||||||
|
return &testCloser{}, nil
|
||||||
}
|
}
|
||||||
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil }
|
|
||||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||||
return errors.New("forwarding request denied")
|
return errors.New("forwarding request denied")
|
||||||
}
|
}
|
||||||
@ -263,12 +344,12 @@ func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) {
|
|||||||
|
|
||||||
func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) {
|
func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) {
|
||||||
oldNewSSHSession := newSSHSession
|
oldNewSSHSession := newSSHSession
|
||||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
newSSHSession = oldNewSSHSession
|
newSSHSession = oldNewSSHSession
|
||||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||||
})
|
})
|
||||||
@ -283,10 +364,12 @@ func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) {
|
|||||||
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||||
return &ssh.Session{}, nil
|
return &ssh.Session{}, nil
|
||||||
}
|
}
|
||||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||||
return sshagent.NewKeyring(), &testCloser{}, nil
|
return nil
|
||||||
|
}
|
||||||
|
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
||||||
|
return &testCloser{}, nil
|
||||||
}
|
}
|
||||||
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil }
|
|
||||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||||
return errors.New("forwarding request denied")
|
return errors.New("forwarding request denied")
|
||||||
}
|
}
|
||||||
@ -298,10 +381,10 @@ func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) {
|
|||||||
|
|
||||||
func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) {
|
func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) {
|
||||||
oldNewSSHSession := newSSHSession
|
oldNewSSHSession := newSSHSession
|
||||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
newSSHSession = oldNewSSHSession
|
newSSHSession = oldNewSSHSession
|
||||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
})
|
})
|
||||||
|
|
||||||
star := &StarSSH{
|
star := &StarSSH{
|
||||||
@ -314,8 +397,8 @@ func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) {
|
|||||||
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||||
return &ssh.Session{}, nil
|
return &ssh.Session{}, nil
|
||||||
}
|
}
|
||||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||||
return nil, nil, errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
|
return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := star.NewExecSession(); err != nil {
|
if _, err := star.NewExecSession(); err != nil {
|
||||||
@ -325,10 +408,10 @@ func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) {
|
|||||||
|
|
||||||
func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) {
|
func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) {
|
||||||
oldNewSSHSession := newSSHSession
|
oldNewSSHSession := newSSHSession
|
||||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
newSSHSession = oldNewSSHSession
|
newSSHSession = oldNewSSHSession
|
||||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
})
|
})
|
||||||
|
|
||||||
star := &StarSSH{
|
star := &StarSSH{
|
||||||
@ -341,8 +424,8 @@ func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) {
|
|||||||
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||||
return &ssh.Session{}, nil
|
return &ssh.Session{}, nil
|
||||||
}
|
}
|
||||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||||
return nil, nil, errors.New("dial unix /tmp/ssh-broken.sock: connect: connection refused")
|
return errors.New("dial unix /tmp/ssh-broken.sock: connect: connection refused")
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := star.NewExecSession(); err != nil {
|
if _, err := star.NewExecSession(); err != nil {
|
||||||
@ -351,11 +434,11 @@ func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) {
|
func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) {
|
||||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||||
oldCloseSSHClient := closeSSHClient
|
oldCloseSSHClient := closeSSHClient
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||||
closeSSHClient = oldCloseSSHClient
|
closeSSHClient = oldCloseSSHClient
|
||||||
})
|
})
|
||||||
@ -370,13 +453,13 @@ func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) {
|
|||||||
started := make(chan struct{})
|
started := make(chan struct{})
|
||||||
release := make(chan struct{})
|
release := make(chan struct{})
|
||||||
closer := &testCloser{}
|
closer := &testCloser{}
|
||||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
||||||
close(started)
|
close(started)
|
||||||
<-release
|
<-release
|
||||||
return sshagent.NewKeyring(), closer, nil
|
return closer, nil
|
||||||
}
|
|
||||||
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error {
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
closeSSHClient = func(client sshClientRequester) error { return nil }
|
closeSSHClient = func(client sshClientRequester) error { return nil }
|
||||||
|
|
||||||
@ -415,3 +498,75 @@ func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) {
|
|||||||
t.Fatal("expected no leaked agent forwarder after close race")
|
t.Fatal("expected no leaked agent forwarder after close race")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxySSHAgentChannelClosesBlockedAgentConnWhenRemoteChannelEnds(t *testing.T) {
|
||||||
|
agentConn, peerConn := net.Pipe()
|
||||||
|
defer peerConn.Close()
|
||||||
|
|
||||||
|
tracked := &trackedConn{Conn: agentConn}
|
||||||
|
channel := newTestSSHChannel(func(p []byte) (int, error) {
|
||||||
|
return 0, io.EOF
|
||||||
|
})
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
proxySSHAgentChannel(channel, tracked)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("proxySSHAgentChannel did not exit after remote EOF")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tracked.closed.Load() == 0 {
|
||||||
|
t.Fatal("expected local agent connection to be closed")
|
||||||
|
}
|
||||||
|
if channel.closed.Load() == 0 {
|
||||||
|
t.Fatal("expected ssh channel to be closed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHAgentForwardProxyCloseClosesActiveBridges(t *testing.T) {
|
||||||
|
agentConn, peerConn := net.Pipe()
|
||||||
|
defer peerConn.Close()
|
||||||
|
|
||||||
|
tracked := &trackedConn{Conn: agentConn}
|
||||||
|
channel := newBlockingTestSSHChannel()
|
||||||
|
proxy := &sshAgentForwardProxy{
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
|
active: make(map[*sshAgentForwardBridge]struct{}),
|
||||||
|
}
|
||||||
|
bridge := &sshAgentForwardBridge{
|
||||||
|
proxy: proxy,
|
||||||
|
channel: channel,
|
||||||
|
conn: tracked,
|
||||||
|
}
|
||||||
|
if !proxy.registerBridge(bridge) {
|
||||||
|
t.Fatal("expected bridge registration to succeed")
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
bridge.run()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := proxy.Close(); err != nil {
|
||||||
|
t.Fatalf("close proxy: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("bridge did not exit after proxy close")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tracked.closed.Load() == 0 {
|
||||||
|
t.Fatal("expected proxy close to close local agent connection")
|
||||||
|
}
|
||||||
|
if channel.closed.Load() == 0 {
|
||||||
|
t.Fatal("expected proxy close to close ssh channel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
34
login.go
34
login.go
@ -16,6 +16,7 @@ import (
|
|||||||
|
|
||||||
var ErrHostKeyCallbackRequired = errors.New("host key callback is required; use DefaultAllowHostKeyCallback to explicitly allow any host key")
|
var ErrHostKeyCallbackRequired = errors.New("host key callback is required; use DefaultAllowHostKeyCallback to explicitly allow any host key")
|
||||||
var errSSHAgentUnavailable = errors.New("ssh-agent unavailable")
|
var errSSHAgentUnavailable = errors.New("ssh-agent unavailable")
|
||||||
|
var buildSSHAgentAuthMethodFunc = buildSSHAgentAuthMethod
|
||||||
|
|
||||||
var defaultAuthOrder = []AuthMethodKind{
|
var defaultAuthOrder = []AuthMethodKind{
|
||||||
AuthMethodSSHAgent,
|
AuthMethodSSHAgent,
|
||||||
@ -42,7 +43,8 @@ func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
|
|||||||
return nil, ErrHostKeyCallbackRequired
|
return nil, ErrHostKeyCallbackRequired
|
||||||
}
|
}
|
||||||
|
|
||||||
loginCtx, cancel := contextWithLoginTimeout(ctx, info.Timeout)
|
authTimeout := effectiveLoginTimeout(info)
|
||||||
|
loginCtx, cancel := contextWithLoginTimeout(ctx, authTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
sshInfo := &StarSSH{
|
sshInfo := &StarSSH{
|
||||||
@ -76,7 +78,7 @@ func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
|
|||||||
clientConfig := &ssh.ClientConfig{
|
clientConfig := &ssh.ClientConfig{
|
||||||
User: info.User,
|
User: info.User,
|
||||||
Auth: auth,
|
Auth: auth,
|
||||||
Timeout: info.Timeout,
|
Timeout: authTimeout,
|
||||||
HostKeyCallback: hostKeyCallback,
|
HostKeyCallback: hostKeyCallback,
|
||||||
BannerCallback: bannerCallback,
|
BannerCallback: bannerCallback,
|
||||||
}
|
}
|
||||||
@ -93,7 +95,7 @@ func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return sshInfo, err
|
return sshInfo, err
|
||||||
}
|
}
|
||||||
restoreDeadline := applyConnDeadline(rawConn, loginCtx, info.Timeout)
|
restoreDeadline := applyConnDeadline(rawConn, loginCtx, authTimeout)
|
||||||
defer restoreDeadline()
|
defer restoreDeadline()
|
||||||
|
|
||||||
clientConn, chans, reqs, err := ssh.NewClientConn(rawConn, targetAddr, clientConfig)
|
clientConn, chans, reqs, err := ssh.NewClientConn(rawConn, targetAddr, clientConfig)
|
||||||
@ -130,6 +132,7 @@ func LoginSimple(host string, user string, passwd string, prikeyPath string, por
|
|||||||
Addr: host,
|
Addr: host,
|
||||||
Port: port,
|
Port: port,
|
||||||
Timeout: timeout,
|
Timeout: timeout,
|
||||||
|
DialTimeout: timeout,
|
||||||
User: user,
|
User: user,
|
||||||
HostKeyCallback: DefaultAllowHostKeyCallback,
|
HostKeyCallback: DefaultAllowHostKeyCallback,
|
||||||
}
|
}
|
||||||
@ -154,12 +157,29 @@ func normalizeLoginInput(info LoginInput) LoginInput {
|
|||||||
if info.Port <= 0 {
|
if info.Port <= 0 {
|
||||||
info.Port = defaultSSHPort
|
info.Port = defaultSSHPort
|
||||||
}
|
}
|
||||||
if info.Timeout <= 0 {
|
|
||||||
info.Timeout = defaultLoginTimeout
|
|
||||||
}
|
|
||||||
return info
|
return info
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func effectiveLoginTimeout(info LoginInput) time.Duration {
|
||||||
|
if info.Timeout <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return info.Timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
func effectiveDialTimeout(info LoginInput) time.Duration {
|
||||||
|
switch {
|
||||||
|
case info.DialTimeout < 0:
|
||||||
|
return 0
|
||||||
|
case info.DialTimeout > 0:
|
||||||
|
return info.DialTimeout
|
||||||
|
case info.Timeout > 0:
|
||||||
|
return info.Timeout
|
||||||
|
default:
|
||||||
|
return defaultLoginTimeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) {
|
func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) {
|
||||||
order, err := normalizeAuthOrder(info.AuthOrder)
|
order, err := normalizeAuthOrder(info.AuthOrder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -194,7 +214,7 @@ func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) {
|
|||||||
if info.DisableSSHAgent {
|
if info.DisableSSHAgent {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
agentMethod, cleanup, err := buildSSHAgentAuthMethod(info.Timeout)
|
agentMethod, cleanup, err := buildSSHAgentAuthMethodFunc(effectiveDialTimeout(info))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
agentErr = err
|
agentErr = err
|
||||||
continue
|
continue
|
||||||
|
|||||||
99
login_timeout_test.go
Normal file
99
login_timeout_test.go
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
package starssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNormalizeLoginInputKeepsZeroAuthTimeout(t *testing.T) {
|
||||||
|
info := normalizeLoginInput(LoginInput{})
|
||||||
|
if info.Port != defaultSSHPort {
|
||||||
|
t.Fatalf("Port=%d want %d", info.Port, defaultSSHPort)
|
||||||
|
}
|
||||||
|
if info.Timeout != 0 {
|
||||||
|
t.Fatalf("Timeout=%v want 0", info.Timeout)
|
||||||
|
}
|
||||||
|
if info.DialTimeout != 0 {
|
||||||
|
t.Fatalf("DialTimeout=%v want 0", info.DialTimeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEffectiveLoginTimeout(t *testing.T) {
|
||||||
|
if got := effectiveLoginTimeout(LoginInput{}); got != 0 {
|
||||||
|
t.Fatalf("zero login timeout should stay zero, got %v", got)
|
||||||
|
}
|
||||||
|
if got := effectiveLoginTimeout(LoginInput{Timeout: 7 * time.Second}); got != 7*time.Second {
|
||||||
|
t.Fatalf("expected explicit login timeout, got %v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEffectiveDialTimeout(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
info LoginInput
|
||||||
|
want time.Duration
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "default fallback",
|
||||||
|
info: LoginInput{},
|
||||||
|
want: defaultLoginTimeout,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "reuse timeout when dial timeout omitted",
|
||||||
|
info: LoginInput{Timeout: 9 * time.Second},
|
||||||
|
want: 9 * time.Second,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit dial timeout wins",
|
||||||
|
info: LoginInput{Timeout: 9 * time.Second, DialTimeout: 3 * time.Second},
|
||||||
|
want: 3 * time.Second,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative dial timeout disables default dial deadline",
|
||||||
|
info: LoginInput{Timeout: 9 * time.Second, DialTimeout: -1},
|
||||||
|
want: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
if got := effectiveDialTimeout(tc.info); got != tc.want {
|
||||||
|
t.Fatalf("effectiveDialTimeout(%+v)=%v want %v", tc.info, got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildAuthMethodsUsesDialTimeoutInsteadOfAuthTimeout(t *testing.T) {
|
||||||
|
oldBuilder := buildSSHAgentAuthMethodFunc
|
||||||
|
t.Cleanup(func() {
|
||||||
|
buildSSHAgentAuthMethodFunc = oldBuilder
|
||||||
|
})
|
||||||
|
|
||||||
|
captured := time.Duration(-2)
|
||||||
|
buildSSHAgentAuthMethodFunc = func(timeout time.Duration) (ssh.AuthMethod, func(), error) {
|
||||||
|
captured = timeout
|
||||||
|
return ssh.Password("agent"), nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
info := LoginInput{
|
||||||
|
Timeout: 0,
|
||||||
|
DialTimeout: 11 * time.Second,
|
||||||
|
AuthOrder: []AuthMethodKind{AuthMethodSSHAgent},
|
||||||
|
}
|
||||||
|
auth, cleanup, err := buildAuthMethods(info)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildAuthMethods: %v", err)
|
||||||
|
}
|
||||||
|
if cleanup != nil {
|
||||||
|
cleanup()
|
||||||
|
}
|
||||||
|
if len(auth) != 1 {
|
||||||
|
t.Fatalf("expected one auth method, got %d", len(auth))
|
||||||
|
}
|
||||||
|
if captured != 11*time.Second {
|
||||||
|
t.Fatalf("agent auth builder timeout=%v want %v", captured, 11*time.Second)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -32,7 +32,7 @@ func resolveDialContext(info LoginInput) DialContextFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dialer := &net.Dialer{
|
dialer := &net.Dialer{
|
||||||
Timeout: info.Timeout,
|
Timeout: effectiveDialTimeout(info),
|
||||||
}
|
}
|
||||||
return dialer.DialContext
|
return dialer.DialContext
|
||||||
}
|
}
|
||||||
@ -44,7 +44,7 @@ func dialTargetConn(ctx context.Context, info LoginInput) (net.Conn, *StarSSH, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
dialContext := resolveDialContext(info)
|
dialContext := resolveDialContext(info)
|
||||||
proxyConfig := normalizeProxyConfig(info.Proxy, info.Timeout)
|
proxyConfig := normalizeProxyConfig(info.Proxy, effectiveDialTimeout(info))
|
||||||
if proxyConfig != nil {
|
if proxyConfig != nil {
|
||||||
return dialViaProxy(ctx, dialContext, *proxyConfig, targetAddr)
|
return dialViaProxy(ctx, dialContext, *proxyConfig, targetAddr)
|
||||||
}
|
}
|
||||||
|
|||||||
22
types.go
22
types.go
@ -94,14 +94,20 @@ type LoginInput struct {
|
|||||||
AuthOrder []AuthMethodKind
|
AuthOrder []AuthMethodKind
|
||||||
Addr string
|
Addr string
|
||||||
Port int
|
Port int
|
||||||
Timeout time.Duration
|
// Timeout limits the SSH handshake/authentication phase after a TCP connection has
|
||||||
DialContext DialContextFunc
|
// already been established. Zero means no authentication timeout.
|
||||||
Proxy *ProxyConfig
|
Timeout time.Duration
|
||||||
Jump *LoginInput
|
// DialTimeout limits outbound dial steps such as TCP connect, proxy connect, and
|
||||||
KeepAliveInterval time.Duration
|
// local ssh-agent socket connect. Zero falls back to Timeout when set, otherwise
|
||||||
KeepAliveTimeout time.Duration
|
// uses the package default dial timeout. Negative disables the default dial timeout.
|
||||||
HostKeyCallback func(string, net.Addr, ssh.PublicKey) error
|
DialTimeout time.Duration
|
||||||
BannerCallback func(string) error
|
DialContext DialContextFunc
|
||||||
|
Proxy *ProxyConfig
|
||||||
|
Jump *LoginInput
|
||||||
|
KeepAliveInterval time.Duration
|
||||||
|
KeepAliveTimeout time.Duration
|
||||||
|
HostKeyCallback func(string, net.Addr, ssh.PublicKey) error
|
||||||
|
BannerCallback func(string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// StarShell keeps the legacy prompt-driven helper for POSIX-style scripted shell interactions.
|
// StarShell keeps the legacy prompt-driven helper for POSIX-style scripted shell interactions.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user