Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ad7c8b0587 | |||
| 1625997d8f | |||
| b29246a9c4 |
355
agent_forward.go
Normal file
355
agent_forward.go
Normal file
@ -0,0 +1,355 @@
|
|||||||
|
package starssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
sshagent "golang.org/x/crypto/ssh/agent"
|
||||||
|
)
|
||||||
|
|
||||||
|
var requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||||
|
return sshagent.RequestAgentForwarding(session)
|
||||||
|
}
|
||||||
|
|
||||||
|
const sshAgentChannelType = "auth-agent@openssh.com"
|
||||||
|
|
||||||
|
var routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
||||||
|
return startSSHAgentForwardProxy(client, timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
var probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||||
|
conn, err := dialSSHAgent(timeout)
|
||||||
|
if err != nil {
|
||||||
|
return wrapSSHAgentForwardingUnavailable(err)
|
||||||
|
}
|
||||||
|
if conn == nil {
|
||||||
|
return wrapSSHAgentForwardingUnavailable(errors.New("empty agent connection"))
|
||||||
|
}
|
||||||
|
return conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
var errSSHAgentForwardingDenied = errors.New("ssh-agent forwarding request denied")
|
||||||
|
var errSSHAgentForwardingUnavailable = errors.New("ssh-agent forwarding unavailable")
|
||||||
|
|
||||||
|
type sshAgentForwardProxy struct {
|
||||||
|
stopOnce sync.Once
|
||||||
|
stopCh chan struct{}
|
||||||
|
|
||||||
|
activeMu sync.Mutex
|
||||||
|
active map[*sshAgentForwardBridge]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *sshAgentForwardProxy) Close() error {
|
||||||
|
if p == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
p.stopOnce.Do(func() {
|
||||||
|
close(p.stopCh)
|
||||||
|
})
|
||||||
|
p.closeActive()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type sshAgentForwardBridge struct {
|
||||||
|
proxy *sshAgentForwardProxy
|
||||||
|
channel ssh.Channel
|
||||||
|
conn net.Conn
|
||||||
|
|
||||||
|
closeOnce sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarSSH) RequestAgentForwarding(session *ssh.Session) error {
|
||||||
|
if s == nil {
|
||||||
|
return errors.New("ssh client is nil")
|
||||||
|
}
|
||||||
|
if session == nil {
|
||||||
|
return errors.New("ssh session is nil")
|
||||||
|
}
|
||||||
|
if err := s.ensureAgentForwarding(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := requestSSHAgentForwarding(session); err != nil {
|
||||||
|
if isSSHAgentForwardingDeniedError(err) {
|
||||||
|
return fmt.Errorf("%w: %v", errSSHAgentForwardingDenied, err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarSSH) maybeRequestAgentForwarding(session *ssh.Session) error {
|
||||||
|
if s == nil || !s.LoginInfo.ForwardSSHAgent {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
err := s.RequestAgentForwarding(session)
|
||||||
|
if isSSHAgentForwardingDeniedError(err) || isSSHAgentForwardingUnavailableError(err) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarSSH) ensureAgentForwarding() error {
|
||||||
|
if s == nil {
|
||||||
|
return errors.New("ssh client is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.agentForwardMu.Lock()
|
||||||
|
defer s.agentForwardMu.Unlock()
|
||||||
|
|
||||||
|
if s.agentForwarder != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := s.requireSSHClient()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout := effectiveDialTimeout(s.LoginInfo)
|
||||||
|
if err := probeSSHAgentForwarding(timeout); err != nil {
|
||||||
|
return wrapSSHAgentForwardingUnavailable(err)
|
||||||
|
}
|
||||||
|
if s.closing.Load() {
|
||||||
|
return errSSHClientClosing
|
||||||
|
}
|
||||||
|
closer, err := routeSSHAgentForwarding(client, timeout)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !s.canAttachAgentForwarder(client) {
|
||||||
|
if closer != nil {
|
||||||
|
_ = closer.Close()
|
||||||
|
}
|
||||||
|
return errSSHClientClosing
|
||||||
|
}
|
||||||
|
s.agentForwarder = closer
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarSSH) takeAgentForwarder() io.Closer {
|
||||||
|
if s == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.agentForwardMu.Lock()
|
||||||
|
defer s.agentForwardMu.Unlock()
|
||||||
|
|
||||||
|
closer := s.agentForwarder
|
||||||
|
s.agentForwarder = nil
|
||||||
|
return closer
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSSHAgentForwardingDeniedError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if errors.Is(err, errSSHAgentForwardingDenied) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
message := strings.ToLower(err.Error())
|
||||||
|
return strings.Contains(message, "forwarding request denied") ||
|
||||||
|
strings.Contains(message, "agent forwarding disabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSSHAgentForwardingUnavailableError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if errors.Is(err, errSSHAgentForwardingUnavailable) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
message := strings.ToLower(err.Error())
|
||||||
|
return strings.Contains(message, "ssh-agent forwarding unavailable") ||
|
||||||
|
strings.Contains(message, "ssh-agent unavailable")
|
||||||
|
}
|
||||||
|
|
||||||
|
func wrapSSHAgentForwardingUnavailable(err error) error {
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if errors.Is(err, errSSHAgentForwardingUnavailable) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if errors.Is(err, errSSHAgentUnavailable) {
|
||||||
|
return fmt.Errorf("%w: %w", errSSHAgentForwardingUnavailable, err)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("%w: %v", errSSHAgentForwardingUnavailable, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func startSSHAgentForwardProxy(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
||||||
|
if client == nil {
|
||||||
|
return nil, errors.New("ssh client is nil")
|
||||||
|
}
|
||||||
|
channels := client.HandleChannelOpen(sshAgentChannelType)
|
||||||
|
if channels == nil {
|
||||||
|
return nil, errors.New("agent: already have handler for " + sshAgentChannelType)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := &sshAgentForwardProxy{
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
|
active: make(map[*sshAgentForwardBridge]struct{}),
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-proxy.stopCh:
|
||||||
|
return
|
||||||
|
case ch, ok := <-channels:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go handleSSHAgentForwardChannel(proxy, ch, timeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return proxy, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleSSHAgentForwardChannel(proxy *sshAgentForwardProxy, ch ssh.NewChannel, timeout time.Duration) {
|
||||||
|
if ch == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn, err := dialSSHAgent(timeout)
|
||||||
|
if err != nil {
|
||||||
|
_ = ch.Reject(ssh.ConnectionFailed, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if conn == nil {
|
||||||
|
_ = ch.Reject(ssh.ConnectionFailed, "ssh-agent connection unavailable")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
channel, reqs, err := ch.Accept()
|
||||||
|
if err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go ssh.DiscardRequests(reqs)
|
||||||
|
|
||||||
|
bridge := &sshAgentForwardBridge{
|
||||||
|
proxy: proxy,
|
||||||
|
channel: channel,
|
||||||
|
conn: conn,
|
||||||
|
}
|
||||||
|
if !proxy.registerBridge(bridge) {
|
||||||
|
bridge.close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go bridge.run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func proxySSHAgentChannel(channel ssh.Channel, conn net.Conn) {
|
||||||
|
bridge := &sshAgentForwardBridge{
|
||||||
|
channel: channel,
|
||||||
|
conn: conn,
|
||||||
|
}
|
||||||
|
bridge.run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *sshAgentForwardBridge) run() {
|
||||||
|
if b == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer b.unregister()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_, _ = io.Copy(b.channel, b.conn)
|
||||||
|
b.close()
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_, _ = io.Copy(b.conn, b.channel)
|
||||||
|
b.close()
|
||||||
|
}()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *sshAgentForwardBridge) close() {
|
||||||
|
if b == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.closeOnce.Do(func() {
|
||||||
|
closeWriter(b.channel)
|
||||||
|
closeWriter(b.conn)
|
||||||
|
if b.channel != nil {
|
||||||
|
_ = b.channel.Close()
|
||||||
|
}
|
||||||
|
if b.conn != nil {
|
||||||
|
_ = b.conn.Close()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *sshAgentForwardBridge) unregister() {
|
||||||
|
if b == nil || b.proxy == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.proxy.unregisterBridge(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *sshAgentForwardProxy) registerBridge(bridge *sshAgentForwardBridge) bool {
|
||||||
|
if p == nil || bridge == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
p.activeMu.Lock()
|
||||||
|
defer p.activeMu.Unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-p.stopCh:
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
if p.active == nil {
|
||||||
|
p.active = make(map[*sshAgentForwardBridge]struct{})
|
||||||
|
}
|
||||||
|
p.active[bridge] = struct{}{}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *sshAgentForwardProxy) unregisterBridge(bridge *sshAgentForwardBridge) {
|
||||||
|
if p == nil || bridge == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.activeMu.Lock()
|
||||||
|
defer p.activeMu.Unlock()
|
||||||
|
|
||||||
|
delete(p.active, bridge)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *sshAgentForwardProxy) closeActive() {
|
||||||
|
if p == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.activeMu.Lock()
|
||||||
|
active := make([]*sshAgentForwardBridge, 0, len(p.active))
|
||||||
|
for bridge := range p.active {
|
||||||
|
active = append(active, bridge)
|
||||||
|
}
|
||||||
|
p.active = make(map[*sshAgentForwardBridge]struct{})
|
||||||
|
p.activeMu.Unlock()
|
||||||
|
|
||||||
|
for _, bridge := range active {
|
||||||
|
bridge.close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func closeWriter(value any) {
|
||||||
|
type closeWriter interface {
|
||||||
|
CloseWrite() error
|
||||||
|
}
|
||||||
|
if cw, ok := value.(closeWriter); ok {
|
||||||
|
_ = cw.CloseWrite()
|
||||||
|
}
|
||||||
|
}
|
||||||
572
agent_forward_test.go
Normal file
572
agent_forward_test.go
Normal file
@ -0,0 +1,572 @@
|
|||||||
|
package starssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testCloser struct {
|
||||||
|
closed atomic.Int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testCloser) Close() error {
|
||||||
|
c.closed.Add(1)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type trackedConn struct {
|
||||||
|
net.Conn
|
||||||
|
closed atomic.Int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *trackedConn) Close() error {
|
||||||
|
c.closed.Add(1)
|
||||||
|
if c.Conn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.Conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
type testSSHChannel struct {
|
||||||
|
readFunc func([]byte) (int, error)
|
||||||
|
|
||||||
|
stderr bytes.Buffer
|
||||||
|
closed atomic.Int32
|
||||||
|
closeOnce sync.Once
|
||||||
|
closeCh chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestSSHChannel(readFunc func([]byte) (int, error)) *testSSHChannel {
|
||||||
|
return &testSSHChannel{
|
||||||
|
readFunc: readFunc,
|
||||||
|
closeCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBlockingTestSSHChannel() *testSSHChannel {
|
||||||
|
ch := newTestSSHChannel(nil)
|
||||||
|
ch.readFunc = func(p []byte) (int, error) {
|
||||||
|
<-ch.closeCh
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
return ch
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testSSHChannel) Read(p []byte) (int, error) {
|
||||||
|
if c == nil {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
if c.readFunc != nil {
|
||||||
|
return c.readFunc(p)
|
||||||
|
}
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testSSHChannel) Write(p []byte) (int, error) {
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testSSHChannel) Close() error {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
c.closeOnce.Do(func() {
|
||||||
|
c.closed.Add(1)
|
||||||
|
close(c.closeCh)
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testSSHChannel) CloseWrite() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testSSHChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testSSHChannel) Stderr() io.ReadWriter {
|
||||||
|
return &c.stderr
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
|
||||||
|
oldNewSSHSession := newSSHSession
|
||||||
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
|
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||||
|
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||||
|
oldCloseSSHClient := closeSSHClient
|
||||||
|
t.Cleanup(func() {
|
||||||
|
newSSHSession = oldNewSSHSession
|
||||||
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
|
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||||
|
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||||
|
closeSSHClient = oldCloseSSHClient
|
||||||
|
})
|
||||||
|
|
||||||
|
baseClient := &ssh.Client{}
|
||||||
|
star := &StarSSH{
|
||||||
|
LoginInfo: LoginInput{
|
||||||
|
ForwardSSHAgent: true,
|
||||||
|
Timeout: time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
star.setTransport(baseClient, nil)
|
||||||
|
|
||||||
|
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||||
|
if client != baseClient {
|
||||||
|
t.Fatalf("unexpected ssh client %p", client)
|
||||||
|
}
|
||||||
|
return &ssh.Session{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var probeCalls atomic.Int32
|
||||||
|
closer := &testCloser{}
|
||||||
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||||
|
probeCalls.Add(1)
|
||||||
|
if timeout != time.Second {
|
||||||
|
t.Fatalf("unexpected forwarding timeout: %v", timeout)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var routeCalls atomic.Int32
|
||||||
|
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
||||||
|
routeCalls.Add(1)
|
||||||
|
if client != baseClient {
|
||||||
|
t.Fatalf("unexpected routed client %p", client)
|
||||||
|
}
|
||||||
|
if timeout != time.Second {
|
||||||
|
t.Fatalf("unexpected routed timeout: %v", timeout)
|
||||||
|
}
|
||||||
|
return closer, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var requestCalls atomic.Int32
|
||||||
|
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||||
|
requestCalls.Add(1)
|
||||||
|
if session == nil {
|
||||||
|
t.Fatal("expected non-nil ssh session")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := star.NewExecSession(); err != nil {
|
||||||
|
t.Fatalf("first exec session: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := star.NewExecSession(); err != nil {
|
||||||
|
t.Fatalf("second exec session: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if probeCalls.Load() != 1 {
|
||||||
|
t.Fatalf("expected one agent probe, got %d", probeCalls.Load())
|
||||||
|
}
|
||||||
|
if routeCalls.Load() != 1 {
|
||||||
|
t.Fatalf("expected one agent route registration, got %d", routeCalls.Load())
|
||||||
|
}
|
||||||
|
if requestCalls.Load() != 2 {
|
||||||
|
t.Fatalf("expected agent forwarding request on each session, got %d", requestCalls.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
closeSSHClient = func(client sshClientRequester) error { return nil }
|
||||||
|
if err := star.Close(); err != nil {
|
||||||
|
t.Fatalf("close starssh: %v", err)
|
||||||
|
}
|
||||||
|
if closer.closed.Load() != 1 {
|
||||||
|
t.Fatalf("expected forwarded agent closer to run once, got %d", closer.closed.Load())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) {
|
||||||
|
oldNewSSHSession := newSSHSession
|
||||||
|
oldRequestSessionPTY := requestSessionPTY
|
||||||
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
|
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||||
|
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||||
|
t.Cleanup(func() {
|
||||||
|
newSSHSession = oldNewSSHSession
|
||||||
|
requestSessionPTY = oldRequestSessionPTY
|
||||||
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
|
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||||
|
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||||
|
})
|
||||||
|
|
||||||
|
star := &StarSSH{
|
||||||
|
LoginInfo: LoginInput{
|
||||||
|
ForwardSSHAgent: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
star.setTransport(&ssh.Client{}, nil)
|
||||||
|
|
||||||
|
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||||
|
return &ssh.Session{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var ptyCalls atomic.Int32
|
||||||
|
requestSessionPTY = func(session *ssh.Session, config TerminalConfig) error {
|
||||||
|
ptyCalls.Add(1)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
||||||
|
return &testCloser{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var requestCalls atomic.Int32
|
||||||
|
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||||
|
requestCalls.Add(1)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := star.NewPTYSession(nil); err != nil {
|
||||||
|
t.Fatalf("new pty session: %v", err)
|
||||||
|
}
|
||||||
|
if ptyCalls.Load() != 1 {
|
||||||
|
t.Fatalf("expected one PTY request, got %d", ptyCalls.Load())
|
||||||
|
}
|
||||||
|
if requestCalls.Load() != 1 {
|
||||||
|
t.Fatalf("expected one agent forwarding request, got %d", requestCalls.Load())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) {
|
||||||
|
oldNewSSHSession := newSSHSession
|
||||||
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
|
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||||
|
t.Cleanup(func() {
|
||||||
|
newSSHSession = oldNewSSHSession
|
||||||
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
|
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||||
|
})
|
||||||
|
|
||||||
|
star := &StarSSH{}
|
||||||
|
star.setTransport(&ssh.Client{}, nil)
|
||||||
|
|
||||||
|
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||||
|
return &ssh.Session{}, nil
|
||||||
|
}
|
||||||
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||||
|
t.Fatal("agent forwarding probe should not run when disabled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||||
|
t.Fatal("agent forwarding should not be requested when disabled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := star.NewExecSession(); err != nil {
|
||||||
|
t.Fatalf("new exec session without forwarding: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestAgentForwardingReturnsUnavailableError(t *testing.T) {
|
||||||
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
|
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||||
|
t.Cleanup(func() {
|
||||||
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
|
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||||
|
})
|
||||||
|
|
||||||
|
star := &StarSSH{}
|
||||||
|
star.setTransport(&ssh.Client{}, nil)
|
||||||
|
|
||||||
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||||
|
return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
|
||||||
|
}
|
||||||
|
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||||
|
t.Fatal("session request should not run when agent forwarder init fails")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := star.RequestAgentForwarding(&ssh.Session{})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected agent forwarding init error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestAgentForwardingWrapsSetupErrorAsUnavailable(t *testing.T) {
|
||||||
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
|
t.Cleanup(func() {
|
||||||
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
|
})
|
||||||
|
|
||||||
|
star := &StarSSH{}
|
||||||
|
star.setTransport(&ssh.Client{}, nil)
|
||||||
|
|
||||||
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||||
|
return errors.New("dial unix /tmp/ssh-broken.sock: connect: permission denied")
|
||||||
|
}
|
||||||
|
|
||||||
|
err := star.RequestAgentForwarding(&ssh.Session{})
|
||||||
|
if !isSSHAgentForwardingUnavailableError(err) {
|
||||||
|
t.Fatalf("expected unavailable error, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) {
|
||||||
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
|
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||||
|
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||||
|
t.Cleanup(func() {
|
||||||
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
|
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||||
|
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||||
|
})
|
||||||
|
|
||||||
|
star := &StarSSH{}
|
||||||
|
star.setTransport(&ssh.Client{}, nil)
|
||||||
|
|
||||||
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
||||||
|
return &testCloser{}, nil
|
||||||
|
}
|
||||||
|
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||||
|
return errors.New("forwarding request denied")
|
||||||
|
}
|
||||||
|
|
||||||
|
err := star.RequestAgentForwarding(&ssh.Session{})
|
||||||
|
if !isSSHAgentForwardingDeniedError(err) {
|
||||||
|
t.Fatalf("expected forwarding denied error, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) {
|
||||||
|
oldNewSSHSession := newSSHSession
|
||||||
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
|
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||||
|
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||||
|
t.Cleanup(func() {
|
||||||
|
newSSHSession = oldNewSSHSession
|
||||||
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
|
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||||
|
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||||
|
})
|
||||||
|
|
||||||
|
star := &StarSSH{
|
||||||
|
LoginInfo: LoginInput{
|
||||||
|
ForwardSSHAgent: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
star.setTransport(&ssh.Client{}, nil)
|
||||||
|
|
||||||
|
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||||
|
return &ssh.Session{}, nil
|
||||||
|
}
|
||||||
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
||||||
|
return &testCloser{}, nil
|
||||||
|
}
|
||||||
|
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||||
|
return errors.New("forwarding request denied")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := star.NewExecSession(); err != nil {
|
||||||
|
t.Fatalf("new exec session should ignore denied agent forwarding: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) {
|
||||||
|
oldNewSSHSession := newSSHSession
|
||||||
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
|
t.Cleanup(func() {
|
||||||
|
newSSHSession = oldNewSSHSession
|
||||||
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
|
})
|
||||||
|
|
||||||
|
star := &StarSSH{
|
||||||
|
LoginInfo: LoginInput{
|
||||||
|
ForwardSSHAgent: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
star.setTransport(&ssh.Client{}, nil)
|
||||||
|
|
||||||
|
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||||
|
return &ssh.Session{}, nil
|
||||||
|
}
|
||||||
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||||
|
return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := star.NewExecSession(); err != nil {
|
||||||
|
t.Fatalf("new exec session should ignore unavailable agent forwarding: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) {
|
||||||
|
oldNewSSHSession := newSSHSession
|
||||||
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
|
t.Cleanup(func() {
|
||||||
|
newSSHSession = oldNewSSHSession
|
||||||
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
|
})
|
||||||
|
|
||||||
|
star := &StarSSH{
|
||||||
|
LoginInfo: LoginInput{
|
||||||
|
ForwardSSHAgent: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
star.setTransport(&ssh.Client{}, nil)
|
||||||
|
|
||||||
|
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||||
|
return &ssh.Session{}, nil
|
||||||
|
}
|
||||||
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||||
|
return errors.New("dial unix /tmp/ssh-broken.sock: connect: connection refused")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := star.NewExecSession(); err != nil {
|
||||||
|
t.Fatalf("new exec session should ignore agent setup error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) {
|
||||||
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
|
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||||
|
oldCloseSSHClient := closeSSHClient
|
||||||
|
t.Cleanup(func() {
|
||||||
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
|
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||||
|
closeSSHClient = oldCloseSSHClient
|
||||||
|
})
|
||||||
|
|
||||||
|
star := &StarSSH{
|
||||||
|
LoginInfo: LoginInput{
|
||||||
|
ForwardSSHAgent: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
star.setTransport(&ssh.Client{}, nil)
|
||||||
|
|
||||||
|
started := make(chan struct{})
|
||||||
|
release := make(chan struct{})
|
||||||
|
closer := &testCloser{}
|
||||||
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
||||||
|
close(started)
|
||||||
|
<-release
|
||||||
|
return closer, nil
|
||||||
|
}
|
||||||
|
closeSSHClient = func(client sshClientRequester) error { return nil }
|
||||||
|
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
errCh <- star.ensureAgentForwarding()
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-started
|
||||||
|
closeDone := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
_ = star.Close()
|
||||||
|
close(closeDone)
|
||||||
|
}()
|
||||||
|
|
||||||
|
deadline := time.Now().Add(time.Second)
|
||||||
|
for !star.closing.Load() {
|
||||||
|
if time.Now().After(deadline) {
|
||||||
|
t.Fatal("close did not enter closing state in time")
|
||||||
|
}
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
close(release)
|
||||||
|
|
||||||
|
err := <-errCh
|
||||||
|
if !errors.Is(err, errSSHClientClosing) {
|
||||||
|
t.Fatalf("expected closing error, got %v", err)
|
||||||
|
}
|
||||||
|
<-closeDone
|
||||||
|
|
||||||
|
if closer.closed.Load() != 1 {
|
||||||
|
t.Fatalf("expected new forwarder closer to be closed once, got %d", closer.closed.Load())
|
||||||
|
}
|
||||||
|
if got := star.takeAgentForwarder(); got != nil {
|
||||||
|
t.Fatal("expected no leaked agent forwarder after close race")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxySSHAgentChannelClosesBlockedAgentConnWhenRemoteChannelEnds(t *testing.T) {
|
||||||
|
agentConn, peerConn := net.Pipe()
|
||||||
|
defer peerConn.Close()
|
||||||
|
|
||||||
|
tracked := &trackedConn{Conn: agentConn}
|
||||||
|
channel := newTestSSHChannel(func(p []byte) (int, error) {
|
||||||
|
return 0, io.EOF
|
||||||
|
})
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
proxySSHAgentChannel(channel, tracked)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("proxySSHAgentChannel did not exit after remote EOF")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tracked.closed.Load() == 0 {
|
||||||
|
t.Fatal("expected local agent connection to be closed")
|
||||||
|
}
|
||||||
|
if channel.closed.Load() == 0 {
|
||||||
|
t.Fatal("expected ssh channel to be closed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHAgentForwardProxyCloseClosesActiveBridges(t *testing.T) {
|
||||||
|
agentConn, peerConn := net.Pipe()
|
||||||
|
defer peerConn.Close()
|
||||||
|
|
||||||
|
tracked := &trackedConn{Conn: agentConn}
|
||||||
|
channel := newBlockingTestSSHChannel()
|
||||||
|
proxy := &sshAgentForwardProxy{
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
|
active: make(map[*sshAgentForwardBridge]struct{}),
|
||||||
|
}
|
||||||
|
bridge := &sshAgentForwardBridge{
|
||||||
|
proxy: proxy,
|
||||||
|
channel: channel,
|
||||||
|
conn: tracked,
|
||||||
|
}
|
||||||
|
if !proxy.registerBridge(bridge) {
|
||||||
|
t.Fatal("expected bridge registration to succeed")
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
bridge.run()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := proxy.Close(); err != nil {
|
||||||
|
t.Fatalf("close proxy: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("bridge did not exit after proxy close")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tracked.closed.Load() == 0 {
|
||||||
|
t.Fatal("expected proxy close to close local agent connection")
|
||||||
|
}
|
||||||
|
if channel.closed.Load() == 0 {
|
||||||
|
t.Fatal("expected proxy close to close ssh channel")
|
||||||
|
}
|
||||||
|
}
|
||||||
306
forward.go
306
forward.go
@ -3,21 +3,39 @@ package starssh
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ForwardRequest struct {
|
type ForwardRequest struct {
|
||||||
|
// Keep the exported shape compatible with older positional literals:
|
||||||
|
// ForwardRequest{listenAddr, targetAddr, dialContext}.
|
||||||
|
//
|
||||||
|
// Non-default networks can be encoded with an explicit scheme-like prefix:
|
||||||
|
// "tcp4://127.0.0.1:22", "tcp6://[::1]:22", "unix:///tmp/socket".
|
||||||
|
// Bare values default to the "tcp" network.
|
||||||
ListenAddr string
|
ListenAddr string
|
||||||
TargetAddr string
|
TargetAddr string
|
||||||
DialContext DialContextFunc
|
DialContext DialContextFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type normalizedForwardRequest struct {
|
||||||
|
ListenNetwork string
|
||||||
|
ListenAddr string
|
||||||
|
TargetNetwork string
|
||||||
|
TargetAddr string
|
||||||
|
DialContext DialContextFunc
|
||||||
|
}
|
||||||
|
|
||||||
type DynamicForwardRequest struct {
|
type DynamicForwardRequest struct {
|
||||||
ListenAddr string
|
ListenAddr string
|
||||||
}
|
}
|
||||||
@ -41,10 +59,16 @@ type PortForwarder struct {
|
|||||||
cleanupFns []func() error
|
cleanupFns []func() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const unixForwardProbeTimeout = 200 * time.Millisecond
|
||||||
|
|
||||||
var dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
|
var dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
|
||||||
return client.Dial(network, address)
|
return client.Dial(network, address)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var listenSSHClient = func(client *ssh.Client, network, address string) (net.Listener, error) {
|
||||||
|
return client.Listen(network, address)
|
||||||
|
}
|
||||||
|
|
||||||
var newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) {
|
var newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) {
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
ctx = context.Background()
|
ctx = context.Background()
|
||||||
@ -64,6 +88,90 @@ func (s *StarSSH) DialTCPContextCloseOnCancel(ctx context.Context, network strin
|
|||||||
return s.dialTCPContext(ctx, network, address, s.Close)
|
return s.dialTCPContext(ctx, network, address, s.Close)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *StarSSH) StartLocalTCPForward(listenAddr string, targetAddr string) (*PortForwarder, error) {
|
||||||
|
return s.StartLocalForward(ForwardRequest{
|
||||||
|
ListenAddr: listenAddr,
|
||||||
|
TargetAddr: targetAddr,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarSSH) StartLocalTCPForwardDetached(listenAddr string, targetAddr string) (*PortForwarder, error) {
|
||||||
|
return s.StartLocalForwardDetached(ForwardRequest{
|
||||||
|
ListenAddr: listenAddr,
|
||||||
|
TargetAddr: targetAddr,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarSSH) StartLocalTCPToUnixForward(listenAddr string, targetPath string) (*PortForwarder, error) {
|
||||||
|
return s.StartLocalForward(ForwardRequest{
|
||||||
|
ListenAddr: listenAddr,
|
||||||
|
TargetAddr: forwardEndpoint("unix", targetPath),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarSSH) StartLocalTCPToUnixForwardDetached(listenAddr string, targetPath string) (*PortForwarder, error) {
|
||||||
|
return s.StartLocalForwardDetached(ForwardRequest{
|
||||||
|
ListenAddr: listenAddr,
|
||||||
|
TargetAddr: forwardEndpoint("unix", targetPath),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarSSH) StartLocalUnixForward(listenPath string, targetAddr string) (*PortForwarder, error) {
|
||||||
|
return s.StartLocalForward(ForwardRequest{
|
||||||
|
ListenAddr: forwardEndpoint("unix", listenPath),
|
||||||
|
TargetAddr: targetAddr,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarSSH) StartLocalUnixForwardDetached(listenPath string, targetAddr string) (*PortForwarder, error) {
|
||||||
|
return s.StartLocalForwardDetached(ForwardRequest{
|
||||||
|
ListenAddr: forwardEndpoint("unix", listenPath),
|
||||||
|
TargetAddr: targetAddr,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarSSH) StartLocalUnixToUnixForward(listenPath string, targetPath string) (*PortForwarder, error) {
|
||||||
|
return s.StartLocalForward(ForwardRequest{
|
||||||
|
ListenAddr: forwardEndpoint("unix", listenPath),
|
||||||
|
TargetAddr: forwardEndpoint("unix", targetPath),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarSSH) StartLocalUnixToUnixForwardDetached(listenPath string, targetPath string) (*PortForwarder, error) {
|
||||||
|
return s.StartLocalForwardDetached(ForwardRequest{
|
||||||
|
ListenAddr: forwardEndpoint("unix", listenPath),
|
||||||
|
TargetAddr: forwardEndpoint("unix", targetPath),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarSSH) StartRemoteTCPForward(listenAddr string, targetAddr string) (*PortForwarder, error) {
|
||||||
|
return s.StartRemoteForward(ForwardRequest{
|
||||||
|
ListenAddr: listenAddr,
|
||||||
|
TargetAddr: targetAddr,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarSSH) StartRemoteTCPToUnixForward(listenAddr string, targetPath string) (*PortForwarder, error) {
|
||||||
|
return s.StartRemoteForward(ForwardRequest{
|
||||||
|
ListenAddr: listenAddr,
|
||||||
|
TargetAddr: forwardEndpoint("unix", targetPath),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarSSH) StartRemoteUnixForward(listenPath string, targetAddr string) (*PortForwarder, error) {
|
||||||
|
return s.StartRemoteForward(ForwardRequest{
|
||||||
|
ListenAddr: forwardEndpoint("unix", listenPath),
|
||||||
|
TargetAddr: targetAddr,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarSSH) StartRemoteUnixToUnixForward(listenPath string, targetPath string) (*PortForwarder, error) {
|
||||||
|
return s.StartRemoteForward(ForwardRequest{
|
||||||
|
ListenAddr: forwardEndpoint("unix", listenPath),
|
||||||
|
TargetAddr: forwardEndpoint("unix", targetPath),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (s *StarSSH) dialTCPContext(ctx context.Context, network string, address string, onCancel func() error) (net.Conn, error) {
|
func (s *StarSSH) dialTCPContext(ctx context.Context, network string, address string, onCancel func() error) (net.Conn, error) {
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
ctx = context.Background()
|
ctx = context.Background()
|
||||||
@ -136,21 +244,22 @@ func (s *StarSSH) StartLocalForward(req ForwardRequest) (*PortForwarder, error)
|
|||||||
if _, err := s.requireSSHClient(); err != nil {
|
if _, err := s.requireSSHClient(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(req.ListenAddr) == "" {
|
normalizedReq, err := normalizeForwardRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(normalizedReq.ListenAddr) == "" {
|
||||||
return nil, errors.New("local forward listen address is empty")
|
return nil, errors.New("local forward listen address is empty")
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(req.TargetAddr) == "" {
|
listener, cleanup, err := prepareLocalForwardListener(normalizedReq.ListenNetwork, normalizedReq.ListenAddr)
|
||||||
return nil, errors.New("local forward target address is empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
listener, err := net.Listen("tcp", req.ListenAddr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
forwarder := newPortForwarder(listener)
|
forwarder := newPortForwarder(listener)
|
||||||
|
forwarder.addCleanup(cleanup)
|
||||||
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
|
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
|
||||||
return s.DialTCPContext(ctx, "tcp", req.TargetAddr)
|
return s.DialTCPContext(ctx, normalizedReq.TargetNetwork, normalizedReq.TargetAddr)
|
||||||
})
|
})
|
||||||
return forwarder, nil
|
return forwarder, nil
|
||||||
}
|
}
|
||||||
@ -159,14 +268,12 @@ func (s *StarSSH) StartLocalForwardDetached(req ForwardRequest) (*PortForwarder,
|
|||||||
if _, err := s.requireSSHClient(); err != nil {
|
if _, err := s.requireSSHClient(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(req.ListenAddr) == "" {
|
normalizedReq, err := normalizeForwardRequest(req)
|
||||||
return nil, errors.New("local forward listen address is empty")
|
if err != nil {
|
||||||
}
|
return nil, err
|
||||||
if strings.TrimSpace(req.TargetAddr) == "" {
|
|
||||||
return nil, errors.New("local forward target address is empty")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
listener, err := net.Listen("tcp", req.ListenAddr)
|
listener, cleanup, err := prepareLocalForwardListener(normalizedReq.ListenNetwork, normalizedReq.ListenAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -174,15 +281,19 @@ func (s *StarSSH) StartLocalForwardDetached(req ForwardRequest) (*PortForwarder,
|
|||||||
forwardClient, err := s.newForwardDialClient(context.Background())
|
forwardClient, err := s.newForwardDialClient(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = listener.Close()
|
_ = listener.Close()
|
||||||
|
if cleanup != nil {
|
||||||
|
_ = cleanup()
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
forwarder := newPortForwarder(listener)
|
forwarder := newPortForwarder(listener)
|
||||||
|
forwarder.addCleanup(cleanup)
|
||||||
forwarder.addCleanup(func() error {
|
forwarder.addCleanup(func() error {
|
||||||
return normalizeAlreadyClosedError(forwardClient.Close())
|
return normalizeAlreadyClosedError(forwardClient.Close())
|
||||||
})
|
})
|
||||||
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
|
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
|
||||||
return forwardClient.DialTCPContext(ctx, "tcp", req.TargetAddr)
|
return forwardClient.DialTCPContext(ctx, normalizedReq.TargetNetwork, normalizedReq.TargetAddr)
|
||||||
})
|
})
|
||||||
return forwarder, nil
|
return forwarder, nil
|
||||||
}
|
}
|
||||||
@ -192,19 +303,17 @@ func (s *StarSSH) StartRemoteForward(req ForwardRequest) (*PortForwarder, error)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(req.ListenAddr) == "" {
|
normalizedReq, err := normalizeForwardRequest(req)
|
||||||
return nil, errors.New("remote forward listen address is empty")
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(req.TargetAddr) == "" {
|
|
||||||
return nil, errors.New("remote forward target address is empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
listener, err := client.Listen("tcp", req.ListenAddr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
dialContext := req.DialContext
|
listener, err := listenSSHClient(client, normalizedReq.ListenNetwork, normalizedReq.ListenAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
dialContext := normalizedReq.DialContext
|
||||||
if dialContext == nil {
|
if dialContext == nil {
|
||||||
dialer := &net.Dialer{
|
dialer := &net.Dialer{
|
||||||
Timeout: defaultLoginTimeout,
|
Timeout: defaultLoginTimeout,
|
||||||
@ -214,7 +323,7 @@ func (s *StarSSH) StartRemoteForward(req ForwardRequest) (*PortForwarder, error)
|
|||||||
|
|
||||||
forwarder := newPortForwarder(listener)
|
forwarder := newPortForwarder(listener)
|
||||||
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
|
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
|
||||||
return dialContext(ctx, "tcp", req.TargetAddr)
|
return dialContext(ctx, normalizedReq.TargetNetwork, normalizedReq.TargetAddr)
|
||||||
})
|
})
|
||||||
return forwarder, nil
|
return forwarder, nil
|
||||||
}
|
}
|
||||||
@ -239,6 +348,74 @@ func (s *StarSSH) StartDynamicForward(req DynamicForwardRequest) (*PortForwarder
|
|||||||
return forwarder, nil
|
return forwarder, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalizeForwardRequest(req ForwardRequest) (normalizedForwardRequest, error) {
|
||||||
|
normalized := normalizedForwardRequest{
|
||||||
|
DialContext: req.DialContext,
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
normalized.ListenNetwork, normalized.ListenAddr, err = parseForwardEndpoint(req.ListenAddr)
|
||||||
|
if err != nil {
|
||||||
|
return normalized, fmt.Errorf("normalize listen address: %w", err)
|
||||||
|
}
|
||||||
|
normalized.TargetNetwork, normalized.TargetAddr, err = parseForwardEndpoint(req.TargetAddr)
|
||||||
|
if err != nil {
|
||||||
|
return normalized, fmt.Errorf("normalize target address: %w", err)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(normalized.ListenAddr) == "" {
|
||||||
|
return normalized, errors.New("forward listen address is empty")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(normalized.TargetAddr) == "" {
|
||||||
|
return normalized, errors.New("forward target address is empty")
|
||||||
|
}
|
||||||
|
return normalized, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeForwardNetwork(network string) string {
|
||||||
|
network = strings.ToLower(strings.TrimSpace(network))
|
||||||
|
if network == "" {
|
||||||
|
return "tcp"
|
||||||
|
}
|
||||||
|
return network
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSupportedForwardNetwork(network string) bool {
|
||||||
|
switch network {
|
||||||
|
case "tcp", "tcp4", "tcp6", "unix":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseForwardEndpoint(value string) (network string, address string, err error) {
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
if value == "" {
|
||||||
|
return "tcp", "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
lowerValue := strings.ToLower(value)
|
||||||
|
for _, prefix := range []string{"tcp4://", "tcp6://", "tcp://", "unix://"} {
|
||||||
|
if strings.HasPrefix(lowerValue, prefix) {
|
||||||
|
network = normalizeForwardNetwork(strings.TrimSuffix(prefix, "://"))
|
||||||
|
address = value[len(prefix):]
|
||||||
|
if !isSupportedForwardNetwork(network) {
|
||||||
|
return "", "", fmt.Errorf("unsupported forward network %q", network)
|
||||||
|
}
|
||||||
|
return network, address, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "tcp", value, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func forwardEndpoint(network string, address string) string {
|
||||||
|
network = normalizeForwardNetwork(network)
|
||||||
|
if network == "tcp" {
|
||||||
|
return address
|
||||||
|
}
|
||||||
|
return network + "://" + address
|
||||||
|
}
|
||||||
|
|
||||||
func (s *StarSSH) StartDynamicForwardDetached(req DynamicForwardRequest) (*PortForwarder, error) {
|
func (s *StarSSH) StartDynamicForwardDetached(req DynamicForwardRequest) (*PortForwarder, error) {
|
||||||
if _, err := s.requireSSHClient(); err != nil {
|
if _, err := s.requireSSHClient(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -344,6 +521,87 @@ func (f *PortForwarder) addCleanup(fn func() error) {
|
|||||||
f.cleanupFns = append(f.cleanupFns, fn)
|
f.cleanupFns = append(f.cleanupFns, fn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func prepareLocalForwardListener(network string, address string) (net.Listener, func() error, error) {
|
||||||
|
network = normalizeForwardNetwork(network)
|
||||||
|
if network != "unix" {
|
||||||
|
listener, err := net.Listen(network, address)
|
||||||
|
return listener, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := removeStaleUnixSocket(address); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
listener, err := net.Listen(network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanup, err := makeUnixSocketCleanup(address)
|
||||||
|
if err != nil {
|
||||||
|
_ = listener.Close()
|
||||||
|
_ = removeUnixSocketPath(address)
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return listener, cleanup, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeStaleUnixSocket(path string) error {
|
||||||
|
info, err := os.Lstat(path)
|
||||||
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if info.Mode()&os.ModeSocket == 0 {
|
||||||
|
return fmt.Errorf("local unix forward path %q already exists and is not a socket", path)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := net.DialTimeout("unix", path, unixForwardProbeTimeout)
|
||||||
|
if err == nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return fmt.Errorf("local unix forward path %q is already in use", path)
|
||||||
|
}
|
||||||
|
if !isStaleUnixSocketDialError(err) {
|
||||||
|
return fmt.Errorf("probe existing unix socket %q: %w", path, err)
|
||||||
|
}
|
||||||
|
return removeUnixSocketPath(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isStaleUnixSocketDialError(err error) bool {
|
||||||
|
return errors.Is(err, syscall.ECONNREFUSED) || errors.Is(err, syscall.ENOENT)
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeUnixSocketCleanup(path string) (func() error, error) {
|
||||||
|
info, err := os.Lstat(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return func() error {
|
||||||
|
current, err := os.Lstat(path)
|
||||||
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if current.Mode()&os.ModeSocket == 0 || !os.SameFile(info, current) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return removeUnixSocketPath(path)
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeUnixSocketPath(path string) error {
|
||||||
|
err := os.Remove(path)
|
||||||
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (f *PortForwarder) runCleanup() {
|
func (f *PortForwarder) runCleanup() {
|
||||||
if f == nil {
|
if f == nil {
|
||||||
return
|
return
|
||||||
|
|||||||
534
forward_test.go
534
forward_test.go
@ -2,8 +2,13 @@ package starssh
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -11,6 +16,59 @@ import (
|
|||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type stubListener struct {
|
||||||
|
addr net.Addr
|
||||||
|
acceptCh chan net.Conn
|
||||||
|
closeCh chan struct{}
|
||||||
|
closeOnce sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
type dialRecord struct {
|
||||||
|
network string
|
||||||
|
addr string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStubListener(addr net.Addr) *stubListener {
|
||||||
|
return &stubListener{
|
||||||
|
addr: addr,
|
||||||
|
acceptCh: make(chan net.Conn, 1),
|
||||||
|
closeCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *stubListener) Accept() (net.Conn, error) {
|
||||||
|
select {
|
||||||
|
case conn, ok := <-l.acceptCh:
|
||||||
|
if !ok {
|
||||||
|
return nil, io.EOF
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
|
case <-l.closeCh:
|
||||||
|
return nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *stubListener) Close() error {
|
||||||
|
l.closeOnce.Do(func() {
|
||||||
|
close(l.closeCh)
|
||||||
|
close(l.acceptCh)
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *stubListener) Addr() net.Addr {
|
||||||
|
return l.addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *stubListener) Push(conn net.Conn) error {
|
||||||
|
select {
|
||||||
|
case <-l.closeCh:
|
||||||
|
return net.ErrClosed
|
||||||
|
case l.acceptCh <- conn:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStartLocalForwardUsesExistingConnectionByDefault(t *testing.T) {
|
func TestStartLocalForwardUsesExistingConnectionByDefault(t *testing.T) {
|
||||||
oldDialSSHClient := dialSSHClient
|
oldDialSSHClient := dialSSHClient
|
||||||
oldNewDetachedForwardClient := newDetachedForwardClient
|
oldNewDetachedForwardClient := newDetachedForwardClient
|
||||||
@ -63,6 +121,64 @@ func TestStartLocalForwardUsesExistingConnectionByDefault(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestForwardRequestLegacyPositionalLiteralDefaultsToTCP(t *testing.T) {
|
||||||
|
dialer := func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := normalizeForwardRequest(ForwardRequest{
|
||||||
|
"127.0.0.1:10022",
|
||||||
|
"example.internal:22",
|
||||||
|
dialer,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalizeForwardRequest: %v", err)
|
||||||
|
}
|
||||||
|
if req.ListenNetwork != "tcp" {
|
||||||
|
t.Fatalf("ListenNetwork=%q want tcp", req.ListenNetwork)
|
||||||
|
}
|
||||||
|
if req.TargetNetwork != "tcp" {
|
||||||
|
t.Fatalf("TargetNetwork=%q want tcp", req.TargetNetwork)
|
||||||
|
}
|
||||||
|
if req.ListenAddr != "127.0.0.1:10022" || req.TargetAddr != "example.internal:22" {
|
||||||
|
t.Fatalf("unexpected normalized request: %+v", req)
|
||||||
|
}
|
||||||
|
if req.DialContext == nil {
|
||||||
|
t.Fatal("expected DialContext to be preserved")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseForwardEndpointTreatsTCPPrefixLikePlainAddress(t *testing.T) {
|
||||||
|
network, address, err := parseForwardEndpoint("tcp:22")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseForwardEndpoint: %v", err)
|
||||||
|
}
|
||||||
|
if network != "tcp" {
|
||||||
|
t.Fatalf("network=%q want tcp", network)
|
||||||
|
}
|
||||||
|
if address != "tcp:22" {
|
||||||
|
t.Fatalf("address=%q want tcp:22", address)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseForwardEndpointSupportsExplicitSchemes(t *testing.T) {
|
||||||
|
network, address, err := parseForwardEndpoint("unix:///tmp/test-forward.sock")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseForwardEndpoint unix: %v", err)
|
||||||
|
}
|
||||||
|
if network != "unix" || address != "/tmp/test-forward.sock" {
|
||||||
|
t.Fatalf("unexpected unix endpoint parse: network=%q address=%q", network, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
network, address, err = parseForwardEndpoint("tcp6://[::1]:2222")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseForwardEndpoint tcp6: %v", err)
|
||||||
|
}
|
||||||
|
if network != "tcp6" || address != "[::1]:2222" {
|
||||||
|
t.Fatalf("unexpected tcp6 endpoint parse: network=%q address=%q", network, address)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStartLocalForwardDetachedUsesSeparateConnection(t *testing.T) {
|
func TestStartLocalForwardDetachedUsesSeparateConnection(t *testing.T) {
|
||||||
oldDialSSHClient := dialSSHClient
|
oldDialSSHClient := dialSSHClient
|
||||||
oldNewDetachedForwardClient := newDetachedForwardClient
|
oldNewDetachedForwardClient := newDetachedForwardClient
|
||||||
@ -132,6 +248,424 @@ func TestStartLocalForwardDetachedUsesSeparateConnection(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStartRemoteForwardSupportsUnixListenAndTCPTarget(t *testing.T) {
|
||||||
|
oldListenSSHClient := listenSSHClient
|
||||||
|
t.Cleanup(func() {
|
||||||
|
listenSSHClient = oldListenSSHClient
|
||||||
|
})
|
||||||
|
|
||||||
|
baseClient := &ssh.Client{}
|
||||||
|
star := &StarSSH{}
|
||||||
|
star.setTransport(baseClient, nil)
|
||||||
|
|
||||||
|
listener := newStubListener(&net.UnixAddr{
|
||||||
|
Name: "/run/user/0/gnupg/S.gpg-agent",
|
||||||
|
Net: "unix",
|
||||||
|
})
|
||||||
|
|
||||||
|
var listenedNetwork string
|
||||||
|
var listenedAddr string
|
||||||
|
listenSSHClient = func(client *ssh.Client, network, address string) (net.Listener, error) {
|
||||||
|
if client != baseClient {
|
||||||
|
t.Fatalf("unexpected ssh client %p", client)
|
||||||
|
}
|
||||||
|
listenedNetwork = network
|
||||||
|
listenedAddr = address
|
||||||
|
return listener, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var targetNetwork string
|
||||||
|
var targetAddr string
|
||||||
|
forwarder, err := star.StartRemoteForward(ForwardRequest{
|
||||||
|
ListenAddr: forwardEndpoint("unix", "/run/user/0/gnupg/S.gpg-agent"),
|
||||||
|
TargetAddr: "127.0.0.1:4321",
|
||||||
|
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
targetNetwork = network
|
||||||
|
targetAddr = address
|
||||||
|
serverConn, clientConn := net.Pipe()
|
||||||
|
go echoForwardPipe(serverConn)
|
||||||
|
return clientConn, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("start remote unix forward: %v", err)
|
||||||
|
}
|
||||||
|
defer forwarder.Close()
|
||||||
|
|
||||||
|
srcPeer, forwardedConn := net.Pipe()
|
||||||
|
defer srcPeer.Close()
|
||||||
|
if err := listener.Push(forwardedConn); err != nil {
|
||||||
|
t.Fatalf("push forwarded connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := []byte("unix-forward")
|
||||||
|
done := make(chan []byte, 1)
|
||||||
|
go func() {
|
||||||
|
reply := make([]byte, len(payload))
|
||||||
|
_, _ = io.ReadFull(srcPeer, reply)
|
||||||
|
done <- reply
|
||||||
|
}()
|
||||||
|
|
||||||
|
if _, err := srcPeer.Write(payload); err != nil {
|
||||||
|
t.Fatalf("write source payload: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case reply := <-done:
|
||||||
|
if string(reply) != string(payload) {
|
||||||
|
t.Fatalf("unexpected remote unix forward reply: %q", string(reply))
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("remote unix forward did not relay payload")
|
||||||
|
}
|
||||||
|
|
||||||
|
if listenedNetwork != "unix" || listenedAddr != "/run/user/0/gnupg/S.gpg-agent" {
|
||||||
|
t.Fatalf("unexpected remote listen request: network=%q addr=%q", listenedNetwork, listenedAddr)
|
||||||
|
}
|
||||||
|
if targetNetwork != "tcp" || targetAddr != "127.0.0.1:4321" {
|
||||||
|
t.Fatalf("unexpected local dial target: network=%q addr=%q", targetNetwork, targetAddr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartLocalUnixForwardUsesUnixListenerAndTCPTarget(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("unix socket smoke test is exercised in WSL/Linux CI path")
|
||||||
|
}
|
||||||
|
|
||||||
|
oldDialSSHClient := dialSSHClient
|
||||||
|
t.Cleanup(func() {
|
||||||
|
dialSSHClient = oldDialSSHClient
|
||||||
|
})
|
||||||
|
|
||||||
|
baseClient := &ssh.Client{}
|
||||||
|
star := &StarSSH{}
|
||||||
|
star.setTransport(baseClient, nil)
|
||||||
|
|
||||||
|
var targetNetwork string
|
||||||
|
var targetAddr string
|
||||||
|
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
|
||||||
|
if client != baseClient {
|
||||||
|
t.Fatalf("unexpected ssh client %p", client)
|
||||||
|
}
|
||||||
|
targetNetwork = network
|
||||||
|
targetAddr = address
|
||||||
|
serverConn, clientConn := net.Pipe()
|
||||||
|
go echoForwardPipe(serverConn)
|
||||||
|
return clientConn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
socketPath := filepath.Join(t.TempDir(), "forward.sock")
|
||||||
|
forwarder, err := star.StartLocalUnixForward(socketPath, "127.0.0.1:4321")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("start local unix forward: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
closeErr := forwarder.Close()
|
||||||
|
if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) {
|
||||||
|
t.Fatalf("close local unix forward: %v", closeErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
conn, err := net.DialTimeout("unix", socketPath, time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial unix forward listener: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
|
||||||
|
payload := []byte("unix-local-forward")
|
||||||
|
if _, err := conn.Write(payload); err != nil {
|
||||||
|
t.Fatalf("write unix forward payload: %v", err)
|
||||||
|
}
|
||||||
|
reply := make([]byte, len(payload))
|
||||||
|
if _, err := io.ReadFull(conn, reply); err != nil {
|
||||||
|
t.Fatalf("read unix forward reply: %v", err)
|
||||||
|
}
|
||||||
|
if string(reply) != string(payload) {
|
||||||
|
t.Fatalf("unexpected unix forward reply: %q", string(reply))
|
||||||
|
}
|
||||||
|
if targetNetwork != "tcp" || targetAddr != "127.0.0.1:4321" {
|
||||||
|
t.Fatalf("unexpected remote dial target: network=%q addr=%q", targetNetwork, targetAddr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartLocalUnixForwardRemovesSocketOnClose(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("unix socket smoke test is exercised in WSL/Linux CI path")
|
||||||
|
}
|
||||||
|
|
||||||
|
oldDialSSHClient := dialSSHClient
|
||||||
|
t.Cleanup(func() {
|
||||||
|
dialSSHClient = oldDialSSHClient
|
||||||
|
})
|
||||||
|
|
||||||
|
baseClient := &ssh.Client{}
|
||||||
|
star := &StarSSH{}
|
||||||
|
star.setTransport(baseClient, nil)
|
||||||
|
|
||||||
|
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
|
||||||
|
serverConn, clientConn := net.Pipe()
|
||||||
|
go echoForwardPipe(serverConn)
|
||||||
|
return clientConn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
socketPath := filepath.Join(t.TempDir(), "cleanup.sock")
|
||||||
|
forwarder, err := star.StartLocalUnixForward(socketPath, "127.0.0.1:4321")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("start local unix forward: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := os.Lstat(socketPath); err != nil {
|
||||||
|
t.Fatalf("socket should exist while forward is running: %v", err)
|
||||||
|
}
|
||||||
|
if err := forwarder.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
|
t.Fatalf("close local unix forward: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := os.Lstat(socketPath); !errors.Is(err, os.ErrNotExist) {
|
||||||
|
t.Fatalf("socket path should be removed on close, got err=%v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartLocalUnixForwardReusesStaleSocketPath(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("unix socket smoke test is exercised in WSL/Linux CI path")
|
||||||
|
}
|
||||||
|
|
||||||
|
oldDialSSHClient := dialSSHClient
|
||||||
|
t.Cleanup(func() {
|
||||||
|
dialSSHClient = oldDialSSHClient
|
||||||
|
})
|
||||||
|
|
||||||
|
baseClient := &ssh.Client{}
|
||||||
|
star := &StarSSH{}
|
||||||
|
star.setTransport(baseClient, nil)
|
||||||
|
|
||||||
|
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
|
||||||
|
serverConn, clientConn := net.Pipe()
|
||||||
|
go echoForwardPipe(serverConn)
|
||||||
|
return clientConn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
socketPath := filepath.Join(t.TempDir(), "stale.sock")
|
||||||
|
staleListener, err := net.ListenUnix("unix", &net.UnixAddr{
|
||||||
|
Name: socketPath,
|
||||||
|
Net: "unix",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create stale unix socket: %v", err)
|
||||||
|
}
|
||||||
|
staleListener.SetUnlinkOnClose(false)
|
||||||
|
if err := staleListener.Close(); err != nil {
|
||||||
|
t.Fatalf("close stale unix socket listener: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := os.Lstat(socketPath); err != nil {
|
||||||
|
t.Fatalf("expected stale unix socket path to remain after close: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
forwarder, err := star.StartLocalUnixForward(socketPath, "127.0.0.1:4321")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("start local unix forward on stale socket path: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
closeErr := forwarder.Close()
|
||||||
|
if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) {
|
||||||
|
t.Fatalf("close local unix forward: %v", closeErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
reply := make([]byte, len("stale-reuse"))
|
||||||
|
conn, err := net.DialTimeout("unix", socketPath, time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial reused unix forward listener: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
if _, err := conn.Write([]byte("stale-reuse")); err != nil {
|
||||||
|
t.Fatalf("write reused unix forward payload: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := io.ReadFull(conn, reply); err != nil {
|
||||||
|
t.Fatalf("read reused unix forward reply: %v", err)
|
||||||
|
}
|
||||||
|
if string(reply) != "stale-reuse" {
|
||||||
|
t.Fatalf("unexpected reply on reused unix forward: %q", string(reply))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartLocalUnixToUnixForwardUsesUnixTarget(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("unix socket smoke test is exercised in WSL/Linux CI path")
|
||||||
|
}
|
||||||
|
|
||||||
|
oldDialSSHClient := dialSSHClient
|
||||||
|
t.Cleanup(func() {
|
||||||
|
dialSSHClient = oldDialSSHClient
|
||||||
|
})
|
||||||
|
|
||||||
|
baseClient := &ssh.Client{}
|
||||||
|
star := &StarSSH{}
|
||||||
|
star.setTransport(baseClient, nil)
|
||||||
|
|
||||||
|
targetSocketPath := filepath.Join(t.TempDir(), "target.sock")
|
||||||
|
targetListener, err := net.Listen("unix", targetSocketPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen target unix socket: %v", err)
|
||||||
|
}
|
||||||
|
defer targetListener.Close()
|
||||||
|
|
||||||
|
done := make(chan []byte, 1)
|
||||||
|
go func() {
|
||||||
|
conn, acceptErr := targetListener.Accept()
|
||||||
|
if acceptErr != nil {
|
||||||
|
done <- nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
buf := make([]byte, 64)
|
||||||
|
n, _ := conn.Read(buf)
|
||||||
|
_, _ = conn.Write(buf[:n])
|
||||||
|
done <- buf[:n]
|
||||||
|
}()
|
||||||
|
|
||||||
|
dialRecordCh := make(chan dialRecord, 1)
|
||||||
|
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
|
||||||
|
if client != baseClient {
|
||||||
|
t.Fatalf("unexpected ssh client %p", client)
|
||||||
|
}
|
||||||
|
dialRecordCh <- dialRecord{network: network, addr: address}
|
||||||
|
var dialer net.Dialer
|
||||||
|
return dialer.DialContext(ctx, network, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
listenSocketPath := filepath.Join(t.TempDir(), "listen.sock")
|
||||||
|
forwarder, err := star.StartLocalUnixToUnixForward(listenSocketPath, targetSocketPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("start local unix-to-unix forward: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
closeErr := forwarder.Close()
|
||||||
|
if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) {
|
||||||
|
t.Fatalf("close local unix-to-unix forward: %v", closeErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
conn, err := net.DialTimeout("unix", listenSocketPath, time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial unix-to-unix listener: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
|
||||||
|
payload := []byte("unix-to-unix")
|
||||||
|
if _, err := conn.Write(payload); err != nil {
|
||||||
|
t.Fatalf("write unix-to-unix payload: %v", err)
|
||||||
|
}
|
||||||
|
reply := make([]byte, len(payload))
|
||||||
|
if _, err := io.ReadFull(conn, reply); err != nil {
|
||||||
|
t.Fatalf("read unix-to-unix reply: %v", err)
|
||||||
|
}
|
||||||
|
if string(reply) != string(payload) {
|
||||||
|
t.Fatalf("unexpected unix-to-unix reply: %q", string(reply))
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case got := <-done:
|
||||||
|
if string(got) != string(payload) {
|
||||||
|
t.Fatalf("unexpected payload seen by target unix socket: %q", string(got))
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("target unix socket did not receive forwarded payload")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case got := <-dialRecordCh:
|
||||||
|
if got.network != "unix" || got.addr != targetSocketPath {
|
||||||
|
t.Fatalf("unexpected unix target dial: network=%q addr=%q", got.network, got.addr)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("did not observe unix target dial")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartLocalTCPToUnixForwardUsesUnixTarget(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("unix socket smoke test is exercised in WSL/Linux CI path")
|
||||||
|
}
|
||||||
|
|
||||||
|
oldDialSSHClient := dialSSHClient
|
||||||
|
t.Cleanup(func() {
|
||||||
|
dialSSHClient = oldDialSSHClient
|
||||||
|
})
|
||||||
|
|
||||||
|
baseClient := &ssh.Client{}
|
||||||
|
star := &StarSSH{}
|
||||||
|
star.setTransport(baseClient, nil)
|
||||||
|
|
||||||
|
targetSocketPath := filepath.Join(t.TempDir(), "target-tcp-to-unix.sock")
|
||||||
|
targetListener, err := net.Listen("unix", targetSocketPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen target unix socket: %v", err)
|
||||||
|
}
|
||||||
|
defer targetListener.Close()
|
||||||
|
|
||||||
|
done := make(chan []byte, 1)
|
||||||
|
go func() {
|
||||||
|
conn, acceptErr := targetListener.Accept()
|
||||||
|
if acceptErr != nil {
|
||||||
|
done <- nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
buf := make([]byte, 64)
|
||||||
|
n, _ := conn.Read(buf)
|
||||||
|
_, _ = conn.Write(buf[:n])
|
||||||
|
done <- buf[:n]
|
||||||
|
}()
|
||||||
|
|
||||||
|
dialRecordCh := make(chan dialRecord, 1)
|
||||||
|
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
|
||||||
|
if client != baseClient {
|
||||||
|
t.Fatalf("unexpected ssh client %p", client)
|
||||||
|
}
|
||||||
|
dialRecordCh <- dialRecord{network: network, addr: address}
|
||||||
|
var dialer net.Dialer
|
||||||
|
return dialer.DialContext(ctx, network, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
forwarder, err := star.StartLocalTCPToUnixForward("127.0.0.1:0", targetSocketPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("start local tcp-to-unix forward: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
closeErr := forwarder.Close()
|
||||||
|
if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) {
|
||||||
|
t.Fatalf("close local tcp-to-unix forward: %v", closeErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
reply := exerciseForwarder(t, forwarder.Addr().String(), []byte("tcp-to-unix"))
|
||||||
|
if string(reply) != "tcp-to-unix" {
|
||||||
|
t.Fatalf("unexpected tcp-to-unix reply: %q", string(reply))
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case got := <-done:
|
||||||
|
if string(got) != "tcp-to-unix" {
|
||||||
|
t.Fatalf("unexpected payload seen by unix target: %q", string(got))
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("unix target did not receive forwarded tcp payload")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case got := <-dialRecordCh:
|
||||||
|
if got.network != "unix" || got.addr != targetSocketPath {
|
||||||
|
t.Fatalf("unexpected unix target dial: network=%q addr=%q", got.network, got.addr)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("did not observe unix target dial")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func echoForwardPipe(conn net.Conn) {
|
func echoForwardPipe(conn net.Conn) {
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
buf := make([]byte, 4096)
|
buf := make([]byte, 4096)
|
||||||
|
|||||||
34
login.go
34
login.go
@ -16,6 +16,7 @@ import (
|
|||||||
|
|
||||||
var ErrHostKeyCallbackRequired = errors.New("host key callback is required; use DefaultAllowHostKeyCallback to explicitly allow any host key")
|
var ErrHostKeyCallbackRequired = errors.New("host key callback is required; use DefaultAllowHostKeyCallback to explicitly allow any host key")
|
||||||
var errSSHAgentUnavailable = errors.New("ssh-agent unavailable")
|
var errSSHAgentUnavailable = errors.New("ssh-agent unavailable")
|
||||||
|
var buildSSHAgentAuthMethodFunc = buildSSHAgentAuthMethod
|
||||||
|
|
||||||
var defaultAuthOrder = []AuthMethodKind{
|
var defaultAuthOrder = []AuthMethodKind{
|
||||||
AuthMethodSSHAgent,
|
AuthMethodSSHAgent,
|
||||||
@ -42,7 +43,8 @@ func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
|
|||||||
return nil, ErrHostKeyCallbackRequired
|
return nil, ErrHostKeyCallbackRequired
|
||||||
}
|
}
|
||||||
|
|
||||||
loginCtx, cancel := contextWithLoginTimeout(ctx, info.Timeout)
|
authTimeout := effectiveLoginTimeout(info)
|
||||||
|
loginCtx, cancel := contextWithLoginTimeout(ctx, authTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
sshInfo := &StarSSH{
|
sshInfo := &StarSSH{
|
||||||
@ -76,7 +78,7 @@ func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
|
|||||||
clientConfig := &ssh.ClientConfig{
|
clientConfig := &ssh.ClientConfig{
|
||||||
User: info.User,
|
User: info.User,
|
||||||
Auth: auth,
|
Auth: auth,
|
||||||
Timeout: info.Timeout,
|
Timeout: authTimeout,
|
||||||
HostKeyCallback: hostKeyCallback,
|
HostKeyCallback: hostKeyCallback,
|
||||||
BannerCallback: bannerCallback,
|
BannerCallback: bannerCallback,
|
||||||
}
|
}
|
||||||
@ -93,7 +95,7 @@ func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return sshInfo, err
|
return sshInfo, err
|
||||||
}
|
}
|
||||||
restoreDeadline := applyConnDeadline(rawConn, loginCtx, info.Timeout)
|
restoreDeadline := applyConnDeadline(rawConn, loginCtx, authTimeout)
|
||||||
defer restoreDeadline()
|
defer restoreDeadline()
|
||||||
|
|
||||||
clientConn, chans, reqs, err := ssh.NewClientConn(rawConn, targetAddr, clientConfig)
|
clientConn, chans, reqs, err := ssh.NewClientConn(rawConn, targetAddr, clientConfig)
|
||||||
@ -130,6 +132,7 @@ func LoginSimple(host string, user string, passwd string, prikeyPath string, por
|
|||||||
Addr: host,
|
Addr: host,
|
||||||
Port: port,
|
Port: port,
|
||||||
Timeout: timeout,
|
Timeout: timeout,
|
||||||
|
DialTimeout: timeout,
|
||||||
User: user,
|
User: user,
|
||||||
HostKeyCallback: DefaultAllowHostKeyCallback,
|
HostKeyCallback: DefaultAllowHostKeyCallback,
|
||||||
}
|
}
|
||||||
@ -154,12 +157,29 @@ func normalizeLoginInput(info LoginInput) LoginInput {
|
|||||||
if info.Port <= 0 {
|
if info.Port <= 0 {
|
||||||
info.Port = defaultSSHPort
|
info.Port = defaultSSHPort
|
||||||
}
|
}
|
||||||
if info.Timeout <= 0 {
|
|
||||||
info.Timeout = defaultLoginTimeout
|
|
||||||
}
|
|
||||||
return info
|
return info
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func effectiveLoginTimeout(info LoginInput) time.Duration {
|
||||||
|
if info.Timeout <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return info.Timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
func effectiveDialTimeout(info LoginInput) time.Duration {
|
||||||
|
switch {
|
||||||
|
case info.DialTimeout < 0:
|
||||||
|
return 0
|
||||||
|
case info.DialTimeout > 0:
|
||||||
|
return info.DialTimeout
|
||||||
|
case info.Timeout > 0:
|
||||||
|
return info.Timeout
|
||||||
|
default:
|
||||||
|
return defaultLoginTimeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) {
|
func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) {
|
||||||
order, err := normalizeAuthOrder(info.AuthOrder)
|
order, err := normalizeAuthOrder(info.AuthOrder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -194,7 +214,7 @@ func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) {
|
|||||||
if info.DisableSSHAgent {
|
if info.DisableSSHAgent {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
agentMethod, cleanup, err := buildSSHAgentAuthMethod(info.Timeout)
|
agentMethod, cleanup, err := buildSSHAgentAuthMethodFunc(effectiveDialTimeout(info))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
agentErr = err
|
agentErr = err
|
||||||
continue
|
continue
|
||||||
|
|||||||
99
login_timeout_test.go
Normal file
99
login_timeout_test.go
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
package starssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNormalizeLoginInputKeepsZeroAuthTimeout(t *testing.T) {
|
||||||
|
info := normalizeLoginInput(LoginInput{})
|
||||||
|
if info.Port != defaultSSHPort {
|
||||||
|
t.Fatalf("Port=%d want %d", info.Port, defaultSSHPort)
|
||||||
|
}
|
||||||
|
if info.Timeout != 0 {
|
||||||
|
t.Fatalf("Timeout=%v want 0", info.Timeout)
|
||||||
|
}
|
||||||
|
if info.DialTimeout != 0 {
|
||||||
|
t.Fatalf("DialTimeout=%v want 0", info.DialTimeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEffectiveLoginTimeout(t *testing.T) {
|
||||||
|
if got := effectiveLoginTimeout(LoginInput{}); got != 0 {
|
||||||
|
t.Fatalf("zero login timeout should stay zero, got %v", got)
|
||||||
|
}
|
||||||
|
if got := effectiveLoginTimeout(LoginInput{Timeout: 7 * time.Second}); got != 7*time.Second {
|
||||||
|
t.Fatalf("expected explicit login timeout, got %v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEffectiveDialTimeout(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
info LoginInput
|
||||||
|
want time.Duration
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "default fallback",
|
||||||
|
info: LoginInput{},
|
||||||
|
want: defaultLoginTimeout,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "reuse timeout when dial timeout omitted",
|
||||||
|
info: LoginInput{Timeout: 9 * time.Second},
|
||||||
|
want: 9 * time.Second,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit dial timeout wins",
|
||||||
|
info: LoginInput{Timeout: 9 * time.Second, DialTimeout: 3 * time.Second},
|
||||||
|
want: 3 * time.Second,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative dial timeout disables default dial deadline",
|
||||||
|
info: LoginInput{Timeout: 9 * time.Second, DialTimeout: -1},
|
||||||
|
want: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
if got := effectiveDialTimeout(tc.info); got != tc.want {
|
||||||
|
t.Fatalf("effectiveDialTimeout(%+v)=%v want %v", tc.info, got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildAuthMethodsUsesDialTimeoutInsteadOfAuthTimeout(t *testing.T) {
|
||||||
|
oldBuilder := buildSSHAgentAuthMethodFunc
|
||||||
|
t.Cleanup(func() {
|
||||||
|
buildSSHAgentAuthMethodFunc = oldBuilder
|
||||||
|
})
|
||||||
|
|
||||||
|
captured := time.Duration(-2)
|
||||||
|
buildSSHAgentAuthMethodFunc = func(timeout time.Duration) (ssh.AuthMethod, func(), error) {
|
||||||
|
captured = timeout
|
||||||
|
return ssh.Password("agent"), nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
info := LoginInput{
|
||||||
|
Timeout: 0,
|
||||||
|
DialTimeout: 11 * time.Second,
|
||||||
|
AuthOrder: []AuthMethodKind{AuthMethodSSHAgent},
|
||||||
|
}
|
||||||
|
auth, cleanup, err := buildAuthMethods(info)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildAuthMethods: %v", err)
|
||||||
|
}
|
||||||
|
if cleanup != nil {
|
||||||
|
cleanup()
|
||||||
|
}
|
||||||
|
if len(auth) != 1 {
|
||||||
|
t.Fatalf("expected one auth method, got %d", len(auth))
|
||||||
|
}
|
||||||
|
if captured != 11*time.Second {
|
||||||
|
t.Fatalf("agent auth builder timeout=%v want %v", captured, 11*time.Second)
|
||||||
|
}
|
||||||
|
}
|
||||||
34
session.go
34
session.go
@ -9,6 +9,14 @@ import (
|
|||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||||
|
return client.NewSession()
|
||||||
|
}
|
||||||
|
|
||||||
|
var requestSessionPTY = func(session *ssh.Session, config TerminalConfig) error {
|
||||||
|
return session.RequestPty(config.Term, config.Rows, config.Columns, config.Modes)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *StarSSH) Close() error {
|
func (s *StarSSH) Close() error {
|
||||||
return s.closeTransport(true)
|
return s.closeTransport(true)
|
||||||
}
|
}
|
||||||
@ -22,7 +30,15 @@ func (s *StarSSH) NewExecSession() (*ssh.Session, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return NewExecSession(client)
|
session, err := NewExecSession(client)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := s.maybeRequestAgentForwarding(session); err != nil {
|
||||||
|
_ = session.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StarSSH) NewPTYSession(config *TerminalConfig) (*ssh.Session, error) {
|
func (s *StarSSH) NewPTYSession(config *TerminalConfig) (*ssh.Session, error) {
|
||||||
@ -30,7 +46,15 @@ func (s *StarSSH) NewPTYSession(config *TerminalConfig) (*ssh.Session, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return NewPTYSession(client, config)
|
session, err := NewPTYSession(client, config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := s.maybeRequestAgentForwarding(session); err != nil {
|
||||||
|
_ = session.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTransferSession(client *ssh.Client) (*ssh.Session, error) {
|
func NewTransferSession(client *ssh.Client) (*ssh.Session, error) {
|
||||||
@ -41,7 +65,7 @@ func NewExecSession(client *ssh.Client) (*ssh.Session, error) {
|
|||||||
if client == nil {
|
if client == nil {
|
||||||
return nil, errors.New("ssh client is nil")
|
return nil, errors.New("ssh client is nil")
|
||||||
}
|
}
|
||||||
return client.NewSession()
|
return newSSHSession(client)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSession(client *ssh.Client) (*ssh.Session, error) {
|
func NewSession(client *ssh.Client) (*ssh.Session, error) {
|
||||||
@ -53,13 +77,13 @@ func NewPTYSession(client *ssh.Client, config *TerminalConfig) (*ssh.Session, er
|
|||||||
return nil, errors.New("ssh client is nil")
|
return nil, errors.New("ssh client is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
session, err := client.NewSession()
|
session, err := newSSHSession(client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg := normalizeTerminalConfig(config)
|
cfg := normalizeTerminalConfig(config)
|
||||||
if err := session.RequestPty(cfg.Term, cfg.Rows, cfg.Columns, cfg.Modes); err != nil {
|
if err := requestSessionPTY(session, cfg); err != nil {
|
||||||
_ = session.Close()
|
_ = session.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
31
state.go
31
state.go
@ -6,6 +6,8 @@ import (
|
|||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var errSSHClientClosing = errors.New("ssh client is closing")
|
||||||
|
|
||||||
type sshClientRequester interface {
|
type sshClientRequester interface {
|
||||||
SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error)
|
SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error)
|
||||||
Close() error
|
Close() error
|
||||||
@ -29,10 +31,19 @@ func (s *StarSSH) snapshotSSHClient() *ssh.Client {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *StarSSH) requireSSHClient() (*ssh.Client, error) {
|
func (s *StarSSH) requireSSHClient() (*ssh.Client, error) {
|
||||||
|
if s == nil {
|
||||||
|
return nil, errors.New("ssh client is nil")
|
||||||
|
}
|
||||||
|
if s.closing.Load() {
|
||||||
|
return nil, errSSHClientClosing
|
||||||
|
}
|
||||||
client := s.snapshotSSHClient()
|
client := s.snapshotSSHClient()
|
||||||
if client == nil {
|
if client == nil {
|
||||||
return nil, errors.New("ssh client is nil")
|
return nil, errors.New("ssh client is nil")
|
||||||
}
|
}
|
||||||
|
if s.closing.Load() {
|
||||||
|
return nil, errSSHClientClosing
|
||||||
|
}
|
||||||
return client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -46,6 +57,7 @@ func (s *StarSSH) setTransport(client *ssh.Client, upstream *StarSSH) {
|
|||||||
s.Client = client
|
s.Client = client
|
||||||
s.upstream = upstream
|
s.upstream = upstream
|
||||||
s.online = client != nil
|
s.online = client != nil
|
||||||
|
s.closing.Store(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StarSSH) detachTransport() (*ssh.Client, *StarSSH) {
|
func (s *StarSSH) detachTransport() (*ssh.Client, *StarSSH) {
|
||||||
@ -84,7 +96,9 @@ func (s *StarSSH) closeTransport(waitKeepalive bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.closing.Store(true)
|
||||||
_ = s.closeReusableSFTPClient()
|
_ = s.closeReusableSFTPClient()
|
||||||
|
agentForwarder := s.takeAgentForwarder()
|
||||||
|
|
||||||
client, upstream := s.detachTransport()
|
client, upstream := s.detachTransport()
|
||||||
stop, done := s.takeKeepaliveHandles()
|
stop, done := s.takeKeepaliveHandles()
|
||||||
@ -93,8 +107,13 @@ func (s *StarSSH) closeTransport(waitKeepalive bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var closeErr error
|
var closeErr error
|
||||||
|
if agentForwarder != nil {
|
||||||
|
closeErr = normalizeAlreadyClosedError(agentForwarder.Close())
|
||||||
|
}
|
||||||
if client != nil {
|
if client != nil {
|
||||||
closeErr = normalizeAlreadyClosedError(closeSSHClient(client))
|
if err := normalizeAlreadyClosedError(closeSSHClient(client)); closeErr == nil {
|
||||||
|
closeErr = err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if waitKeepalive && done != nil {
|
if waitKeepalive && done != nil {
|
||||||
<-done
|
<-done
|
||||||
@ -104,3 +123,13 @@ func (s *StarSSH) closeTransport(waitKeepalive bool) error {
|
|||||||
}
|
}
|
||||||
return closeErr
|
return closeErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *StarSSH) canAttachAgentForwarder(client *ssh.Client) bool {
|
||||||
|
if s == nil || client == nil || s.closing.Load() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
s.stateMu.RLock()
|
||||||
|
defer s.stateMu.RUnlock()
|
||||||
|
return !s.closing.Load() && s.Client == client
|
||||||
|
}
|
||||||
|
|||||||
@ -32,7 +32,7 @@ func resolveDialContext(info LoginInput) DialContextFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dialer := &net.Dialer{
|
dialer := &net.Dialer{
|
||||||
Timeout: info.Timeout,
|
Timeout: effectiveDialTimeout(info),
|
||||||
}
|
}
|
||||||
return dialer.DialContext
|
return dialer.DialContext
|
||||||
}
|
}
|
||||||
@ -44,7 +44,7 @@ func dialTargetConn(ctx context.Context, info LoginInput) (net.Conn, *StarSSH, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
dialContext := resolveDialContext(info)
|
dialContext := resolveDialContext(info)
|
||||||
proxyConfig := normalizeProxyConfig(info.Proxy, info.Timeout)
|
proxyConfig := normalizeProxyConfig(info.Proxy, effectiveDialTimeout(info))
|
||||||
if proxyConfig != nil {
|
if proxyConfig != nil {
|
||||||
return dialViaProxy(ctx, dialContext, *proxyConfig, targetAddr)
|
return dialViaProxy(ctx, dialContext, *proxyConfig, targetAddr)
|
||||||
}
|
}
|
||||||
|
|||||||
11
types.go
11
types.go
@ -6,6 +6,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/sftp"
|
"github.com/pkg/sftp"
|
||||||
@ -70,9 +71,12 @@ type StarSSH struct {
|
|||||||
upstream *StarSSH
|
upstream *StarSSH
|
||||||
sftpClient *sftp.Client
|
sftpClient *sftp.Client
|
||||||
sftpMu sync.Mutex
|
sftpMu sync.Mutex
|
||||||
|
agentForwardMu sync.Mutex
|
||||||
|
agentForwarder io.Closer
|
||||||
keepaliveMu sync.Mutex
|
keepaliveMu sync.Mutex
|
||||||
keepaliveStop chan struct{}
|
keepaliveStop chan struct{}
|
||||||
keepaliveDone chan struct{}
|
keepaliveDone chan struct{}
|
||||||
|
closing atomic.Bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type LoginInput struct {
|
type LoginInput struct {
|
||||||
@ -86,10 +90,17 @@ type LoginInput struct {
|
|||||||
Prikey string
|
Prikey string
|
||||||
PrikeyPwd string
|
PrikeyPwd string
|
||||||
DisableSSHAgent bool
|
DisableSSHAgent bool
|
||||||
|
ForwardSSHAgent bool
|
||||||
AuthOrder []AuthMethodKind
|
AuthOrder []AuthMethodKind
|
||||||
Addr string
|
Addr string
|
||||||
Port int
|
Port int
|
||||||
|
// Timeout limits the SSH handshake/authentication phase after a TCP connection has
|
||||||
|
// already been established. Zero means no authentication timeout.
|
||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
|
// DialTimeout limits outbound dial steps such as TCP connect, proxy connect, and
|
||||||
|
// local ssh-agent socket connect. Zero falls back to Timeout when set, otherwise
|
||||||
|
// uses the package default dial timeout. Negative disables the default dial timeout.
|
||||||
|
DialTimeout time.Duration
|
||||||
DialContext DialContextFunc
|
DialContext DialContextFunc
|
||||||
Proxy *ProxyConfig
|
Proxy *ProxyConfig
|
||||||
Jump *LoginInput
|
Jump *LoginInput
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user