Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
672a111ec1
|
|||
|
0c23e7d4bf
|
|||
|
ad7c8b0587
|
|||
|
1625997d8f
|
+316
-13
@@ -4,7 +4,9 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
@@ -15,24 +17,57 @@ var requestSSHAgentForwarding = func(session *ssh.Session) error {
|
|||||||
return sshagent.RequestAgentForwarding(session)
|
return sshagent.RequestAgentForwarding(session)
|
||||||
}
|
}
|
||||||
|
|
||||||
var routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error {
|
const sshAgentChannelType = "auth-agent@openssh.com"
|
||||||
return sshagent.ForwardToAgent(client, keyring)
|
|
||||||
|
var routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
|
||||||
|
return startSSHAgentForwardProxy(client, timeouts)
|
||||||
}
|
}
|
||||||
|
|
||||||
var newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
var probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||||
conn, err := dialSSHAgent(timeout)
|
conn, _, err := dialSSHAgentWithDebug("forward-probe", timeouts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, wrapSSHAgentForwardingUnavailable(err)
|
return wrapSSHAgentForwardingUnavailable(err)
|
||||||
}
|
}
|
||||||
if conn == nil {
|
if conn == nil {
|
||||||
return nil, nil, wrapSSHAgentForwardingUnavailable(errors.New("empty agent connection"))
|
return wrapSSHAgentForwardingUnavailable(errors.New("empty agent connection"))
|
||||||
}
|
}
|
||||||
return sshagent.NewClient(conn), conn, nil
|
return conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
var errSSHAgentForwardingDenied = errors.New("ssh-agent forwarding request denied")
|
var errSSHAgentForwardingDenied = errors.New("ssh-agent forwarding request denied")
|
||||||
var errSSHAgentForwardingUnavailable = errors.New("ssh-agent forwarding unavailable")
|
var errSSHAgentForwardingUnavailable = errors.New("ssh-agent forwarding unavailable")
|
||||||
|
|
||||||
|
type sshAgentForwardProxy struct {
|
||||||
|
stopOnce sync.Once
|
||||||
|
stopCh chan struct{}
|
||||||
|
|
||||||
|
activeMu sync.Mutex
|
||||||
|
active map[*sshAgentForwardBridge]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *sshAgentForwardProxy) Close() error {
|
||||||
|
if p == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
p.stopOnce.Do(func() {
|
||||||
|
close(p.stopCh)
|
||||||
|
})
|
||||||
|
p.closeActive()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type sshAgentForwardBridge struct {
|
||||||
|
proxy *sshAgentForwardProxy
|
||||||
|
channel ssh.Channel
|
||||||
|
conn net.Conn
|
||||||
|
idleTimeout time.Duration
|
||||||
|
|
||||||
|
closeOnce sync.Once
|
||||||
|
signalOnce sync.Once
|
||||||
|
done chan struct{}
|
||||||
|
activity chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *StarSSH) RequestAgentForwarding(session *ssh.Session) error {
|
func (s *StarSSH) RequestAgentForwarding(session *ssh.Session) error {
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return errors.New("ssh client is nil")
|
return errors.New("ssh client is nil")
|
||||||
@@ -80,20 +115,21 @@ func (s *StarSSH) ensureAgentForwarding() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
keyring, closer, err := newSSHAgentForwarder(s.LoginInfo.Timeout)
|
timeouts := effectiveSSHAgentTimeouts(s.LoginInfo)
|
||||||
if err != nil {
|
if err := probeSSHAgentForwarding(timeouts); err != nil {
|
||||||
return wrapSSHAgentForwardingUnavailable(err)
|
return wrapSSHAgentForwardingUnavailable(err)
|
||||||
}
|
}
|
||||||
if s.closing.Load() {
|
if s.closing.Load() {
|
||||||
_ = closer.Close()
|
|
||||||
return errSSHClientClosing
|
return errSSHClientClosing
|
||||||
}
|
}
|
||||||
if err := routeSSHAgentForwarding(client, keyring); err != nil {
|
closer, err := routeSSHAgentForwarding(client, timeouts)
|
||||||
_ = closer.Close()
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if !s.canAttachAgentForwarder(client) {
|
if !s.canAttachAgentForwarder(client) {
|
||||||
_ = closer.Close()
|
if closer != nil {
|
||||||
|
_ = closer.Close()
|
||||||
|
}
|
||||||
return errSSHClientClosing
|
return errSSHClientClosing
|
||||||
}
|
}
|
||||||
s.agentForwarder = closer
|
s.agentForwarder = closer
|
||||||
@@ -149,3 +185,270 @@ func wrapSSHAgentForwardingUnavailable(err error) error {
|
|||||||
}
|
}
|
||||||
return fmt.Errorf("%w: %v", errSSHAgentForwardingUnavailable, err)
|
return fmt.Errorf("%w: %v", errSSHAgentForwardingUnavailable, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func startSSHAgentForwardProxy(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
|
||||||
|
if client == nil {
|
||||||
|
return nil, errors.New("ssh client is nil")
|
||||||
|
}
|
||||||
|
channels := client.HandleChannelOpen(sshAgentChannelType)
|
||||||
|
if channels == nil {
|
||||||
|
return nil, errors.New("agent: already have handler for " + sshAgentChannelType)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := &sshAgentForwardProxy{
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
|
active: make(map[*sshAgentForwardBridge]struct{}),
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-proxy.stopCh:
|
||||||
|
return
|
||||||
|
case ch, ok := <-channels:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go handleSSHAgentForwardChannel(proxy, ch, timeouts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return proxy, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleSSHAgentForwardChannel(proxy *sshAgentForwardProxy, ch ssh.NewChannel, timeouts sshAgentTimeouts) {
|
||||||
|
if ch == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn, _, err := dialSSHAgentWithDebug("forward-channel", timeouts)
|
||||||
|
if err != nil {
|
||||||
|
_ = ch.Reject(ssh.ConnectionFailed, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if conn == nil {
|
||||||
|
_ = ch.Reject(ssh.ConnectionFailed, "ssh-agent connection unavailable")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
channel, reqs, err := ch.Accept()
|
||||||
|
if err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go ssh.DiscardRequests(reqs)
|
||||||
|
|
||||||
|
bridge := &sshAgentForwardBridge{
|
||||||
|
proxy: proxy,
|
||||||
|
channel: channel,
|
||||||
|
conn: conn,
|
||||||
|
idleTimeout: timeouts.Forward,
|
||||||
|
}
|
||||||
|
if !proxy.registerBridge(bridge) {
|
||||||
|
bridge.close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go bridge.run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func proxySSHAgentChannel(channel ssh.Channel, conn net.Conn) {
|
||||||
|
bridge := &sshAgentForwardBridge{
|
||||||
|
channel: channel,
|
||||||
|
conn: conn,
|
||||||
|
}
|
||||||
|
bridge.run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *sshAgentForwardBridge) run() {
|
||||||
|
if b == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.ensureSignals()
|
||||||
|
stopWatchdog := b.startIdleWatchdog()
|
||||||
|
defer stopWatchdog()
|
||||||
|
defer b.unregister()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_, _ = io.Copy(
|
||||||
|
sshAgentForwardActivityWriter{Writer: b.channel, touch: b.touch},
|
||||||
|
sshAgentForwardActivityReader{Reader: b.conn, touch: b.touch},
|
||||||
|
)
|
||||||
|
b.close()
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_, _ = io.Copy(
|
||||||
|
sshAgentForwardActivityWriter{Writer: b.conn, touch: b.touch},
|
||||||
|
sshAgentForwardActivityReader{Reader: b.channel, touch: b.touch},
|
||||||
|
)
|
||||||
|
b.close()
|
||||||
|
}()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *sshAgentForwardBridge) close() {
|
||||||
|
if b == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.closeOnce.Do(func() {
|
||||||
|
b.ensureSignals()
|
||||||
|
close(b.done)
|
||||||
|
closeWriter(b.channel)
|
||||||
|
closeWriter(b.conn)
|
||||||
|
if b.channel != nil {
|
||||||
|
_ = b.channel.Close()
|
||||||
|
}
|
||||||
|
if b.conn != nil {
|
||||||
|
_ = b.conn.Close()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *sshAgentForwardBridge) ensureSignals() {
|
||||||
|
if b == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.signalOnce.Do(func() {
|
||||||
|
b.done = make(chan struct{})
|
||||||
|
b.activity = make(chan struct{}, 1)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *sshAgentForwardBridge) startIdleWatchdog() func() {
|
||||||
|
if b == nil || b.idleTimeout <= 0 {
|
||||||
|
return func() {}
|
||||||
|
}
|
||||||
|
b.ensureSignals()
|
||||||
|
timer := time.NewTimer(b.idleTimeout)
|
||||||
|
stopped := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer timer.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
b.close()
|
||||||
|
return
|
||||||
|
case <-b.activity:
|
||||||
|
resetTimer(timer, b.idleTimeout)
|
||||||
|
case <-b.done:
|
||||||
|
return
|
||||||
|
case <-stopped:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return func() {
|
||||||
|
close(stopped)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *sshAgentForwardBridge) touch() {
|
||||||
|
if b == nil || b.idleTimeout <= 0 || b.activity == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case b.activity <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type sshAgentForwardActivityReader struct {
|
||||||
|
io.Reader
|
||||||
|
touch func()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r sshAgentForwardActivityReader) Read(p []byte) (int, error) {
|
||||||
|
n, err := r.Reader.Read(p)
|
||||||
|
if n > 0 && r.touch != nil {
|
||||||
|
r.touch()
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type sshAgentForwardActivityWriter struct {
|
||||||
|
io.Writer
|
||||||
|
touch func()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w sshAgentForwardActivityWriter) Write(p []byte) (int, error) {
|
||||||
|
n, err := w.Writer.Write(p)
|
||||||
|
if n > 0 && w.touch != nil {
|
||||||
|
w.touch()
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func resetTimer(timer *time.Timer, timeout time.Duration) {
|
||||||
|
if !timer.Stop() {
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
timer.Reset(timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *sshAgentForwardBridge) unregister() {
|
||||||
|
if b == nil || b.proxy == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.proxy.unregisterBridge(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *sshAgentForwardProxy) registerBridge(bridge *sshAgentForwardBridge) bool {
|
||||||
|
if p == nil || bridge == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
p.activeMu.Lock()
|
||||||
|
defer p.activeMu.Unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-p.stopCh:
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
if p.active == nil {
|
||||||
|
p.active = make(map[*sshAgentForwardBridge]struct{})
|
||||||
|
}
|
||||||
|
p.active[bridge] = struct{}{}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *sshAgentForwardProxy) unregisterBridge(bridge *sshAgentForwardBridge) {
|
||||||
|
if p == nil || bridge == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.activeMu.Lock()
|
||||||
|
defer p.activeMu.Unlock()
|
||||||
|
|
||||||
|
delete(p.active, bridge)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *sshAgentForwardProxy) closeActive() {
|
||||||
|
if p == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.activeMu.Lock()
|
||||||
|
active := make([]*sshAgentForwardBridge, 0, len(p.active))
|
||||||
|
for bridge := range p.active {
|
||||||
|
active = append(active, bridge)
|
||||||
|
}
|
||||||
|
p.active = make(map[*sshAgentForwardBridge]struct{})
|
||||||
|
p.activeMu.Unlock()
|
||||||
|
|
||||||
|
for _, bridge := range active {
|
||||||
|
bridge.close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func closeWriter(value any) {
|
||||||
|
type closeWriter interface {
|
||||||
|
CloseWrite() error
|
||||||
|
}
|
||||||
|
if cw, ok := value.(closeWriter); ok {
|
||||||
|
_ = cw.CloseWrite()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
+308
-60
@@ -1,14 +1,16 @@
|
|||||||
package starssh
|
package starssh
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
sshagent "golang.org/x/crypto/ssh/agent"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type testCloser struct {
|
type testCloser struct {
|
||||||
@@ -20,15 +22,116 @@ func (c *testCloser) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type trackedConn struct {
|
||||||
|
net.Conn
|
||||||
|
closed atomic.Int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *trackedConn) Close() error {
|
||||||
|
c.closed.Add(1)
|
||||||
|
if c.Conn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.Conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
type testSSHChannel struct {
|
||||||
|
readFunc func([]byte) (int, error)
|
||||||
|
|
||||||
|
stderr bytes.Buffer
|
||||||
|
closed atomic.Int32
|
||||||
|
closeOnce sync.Once
|
||||||
|
closeCh chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type testNewChannel struct {
|
||||||
|
channel ssh.Channel
|
||||||
|
accepted atomic.Bool
|
||||||
|
rejected atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testNewChannel) Accept() (ssh.Channel, <-chan *ssh.Request, error) {
|
||||||
|
c.accepted.Store(true)
|
||||||
|
requests := make(chan *ssh.Request)
|
||||||
|
close(requests)
|
||||||
|
return c.channel, requests, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testNewChannel) Reject(reason ssh.RejectionReason, message string) error {
|
||||||
|
c.rejected.Store(true)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testNewChannel) ChannelType() string {
|
||||||
|
return sshAgentChannelType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testNewChannel) ExtraData() []byte {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestSSHChannel(readFunc func([]byte) (int, error)) *testSSHChannel {
|
||||||
|
return &testSSHChannel{
|
||||||
|
readFunc: readFunc,
|
||||||
|
closeCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBlockingTestSSHChannel() *testSSHChannel {
|
||||||
|
ch := newTestSSHChannel(nil)
|
||||||
|
ch.readFunc = func(p []byte) (int, error) {
|
||||||
|
<-ch.closeCh
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
return ch
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testSSHChannel) Read(p []byte) (int, error) {
|
||||||
|
if c == nil {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
if c.readFunc != nil {
|
||||||
|
return c.readFunc(p)
|
||||||
|
}
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testSSHChannel) Write(p []byte) (int, error) {
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testSSHChannel) Close() error {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
c.closeOnce.Do(func() {
|
||||||
|
c.closed.Add(1)
|
||||||
|
close(c.closeCh)
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testSSHChannel) CloseWrite() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testSSHChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testSSHChannel) Stderr() io.ReadWriter {
|
||||||
|
return &c.stderr
|
||||||
|
}
|
||||||
|
|
||||||
func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
|
func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
|
||||||
oldNewSSHSession := newSSHSession
|
oldNewSSHSession := newSSHSession
|
||||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||||
oldCloseSSHClient := closeSSHClient
|
oldCloseSSHClient := closeSSHClient
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
newSSHSession = oldNewSSHSession
|
newSSHSession = oldNewSSHSession
|
||||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||||
closeSSHClient = oldCloseSSHClient
|
closeSSHClient = oldCloseSSHClient
|
||||||
@@ -37,8 +140,10 @@ func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
|
|||||||
baseClient := &ssh.Client{}
|
baseClient := &ssh.Client{}
|
||||||
star := &StarSSH{
|
star := &StarSSH{
|
||||||
LoginInfo: LoginInput{
|
LoginInfo: LoginInput{
|
||||||
ForwardSSHAgent: true,
|
ForwardSSHAgent: true,
|
||||||
Timeout: time.Second,
|
Timeout: time.Second,
|
||||||
|
SSHAgentTimeout: 3 * time.Second,
|
||||||
|
SSHAgentForwardTimeout: 4 * time.Second,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
star.setTransport(baseClient, nil)
|
star.setTransport(baseClient, nil)
|
||||||
@@ -50,26 +155,38 @@ func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
|
|||||||
return &ssh.Session{}, nil
|
return &ssh.Session{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var agentInitCalls atomic.Int32
|
var probeCalls atomic.Int32
|
||||||
closer := &testCloser{}
|
closer := &testCloser{}
|
||||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||||
agentInitCalls.Add(1)
|
probeCalls.Add(1)
|
||||||
if timeout != time.Second {
|
if timeouts.Dial != time.Second {
|
||||||
t.Fatalf("unexpected forwarding timeout: %v", timeout)
|
t.Fatalf("unexpected forwarding dial timeout: %v", timeouts.Dial)
|
||||||
}
|
}
|
||||||
return sshagent.NewKeyring(), closer, nil
|
if timeouts.Operation != 3*time.Second {
|
||||||
|
t.Fatalf("unexpected forwarding operation timeout: %v", timeouts.Operation)
|
||||||
|
}
|
||||||
|
if timeouts.Forward != 4*time.Second {
|
||||||
|
t.Fatalf("unexpected forwarding idle timeout: %v", timeouts.Forward)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var routeCalls atomic.Int32
|
var routeCalls atomic.Int32
|
||||||
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error {
|
routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
|
||||||
routeCalls.Add(1)
|
routeCalls.Add(1)
|
||||||
if client != baseClient {
|
if client != baseClient {
|
||||||
t.Fatalf("unexpected routed client %p", client)
|
t.Fatalf("unexpected routed client %p", client)
|
||||||
}
|
}
|
||||||
if keyring == nil {
|
if timeouts.Dial != time.Second {
|
||||||
t.Fatal("expected non-nil forwarded agent keyring")
|
t.Fatalf("unexpected routed dial timeout: %v", timeouts.Dial)
|
||||||
}
|
}
|
||||||
return nil
|
if timeouts.Operation != 3*time.Second {
|
||||||
|
t.Fatalf("unexpected routed operation timeout: %v", timeouts.Operation)
|
||||||
|
}
|
||||||
|
if timeouts.Forward != 4*time.Second {
|
||||||
|
t.Fatalf("unexpected routed idle timeout: %v", timeouts.Forward)
|
||||||
|
}
|
||||||
|
return closer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var requestCalls atomic.Int32
|
var requestCalls atomic.Int32
|
||||||
@@ -88,8 +205,8 @@ func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
|
|||||||
t.Fatalf("second exec session: %v", err)
|
t.Fatalf("second exec session: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if agentInitCalls.Load() != 1 {
|
if probeCalls.Load() != 1 {
|
||||||
t.Fatalf("expected one agent forwarder init, got %d", agentInitCalls.Load())
|
t.Fatalf("expected one agent probe, got %d", probeCalls.Load())
|
||||||
}
|
}
|
||||||
if routeCalls.Load() != 1 {
|
if routeCalls.Load() != 1 {
|
||||||
t.Fatalf("expected one agent route registration, got %d", routeCalls.Load())
|
t.Fatalf("expected one agent route registration, got %d", routeCalls.Load())
|
||||||
@@ -110,13 +227,13 @@ func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
|
|||||||
func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) {
|
func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) {
|
||||||
oldNewSSHSession := newSSHSession
|
oldNewSSHSession := newSSHSession
|
||||||
oldRequestSessionPTY := requestSessionPTY
|
oldRequestSessionPTY := requestSessionPTY
|
||||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
newSSHSession = oldNewSSHSession
|
newSSHSession = oldNewSSHSession
|
||||||
requestSessionPTY = oldRequestSessionPTY
|
requestSessionPTY = oldRequestSessionPTY
|
||||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||||
})
|
})
|
||||||
@@ -138,10 +255,12 @@ func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||||
return sshagent.NewKeyring(), &testCloser{}, nil
|
return nil
|
||||||
|
}
|
||||||
|
routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
|
||||||
|
return &testCloser{}, nil
|
||||||
}
|
}
|
||||||
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil }
|
|
||||||
|
|
||||||
var requestCalls atomic.Int32
|
var requestCalls atomic.Int32
|
||||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||||
@@ -162,11 +281,11 @@ func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) {
|
|||||||
|
|
||||||
func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) {
|
func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) {
|
||||||
oldNewSSHSession := newSSHSession
|
oldNewSSHSession := newSSHSession
|
||||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
newSSHSession = oldNewSSHSession
|
newSSHSession = oldNewSSHSession
|
||||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -176,9 +295,9 @@ func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) {
|
|||||||
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||||
return &ssh.Session{}, nil
|
return &ssh.Session{}, nil
|
||||||
}
|
}
|
||||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||||
t.Fatal("agent forwarder should not initialize when disabled")
|
t.Fatal("agent forwarding probe should not run when disabled")
|
||||||
return nil, nil, nil
|
return nil
|
||||||
}
|
}
|
||||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||||
t.Fatal("agent forwarding should not be requested when disabled")
|
t.Fatal("agent forwarding should not be requested when disabled")
|
||||||
@@ -191,18 +310,18 @@ func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRequestAgentForwardingReturnsUnavailableError(t *testing.T) {
|
func TestRequestAgentForwardingReturnsUnavailableError(t *testing.T) {
|
||||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||||
})
|
})
|
||||||
|
|
||||||
star := &StarSSH{}
|
star := &StarSSH{}
|
||||||
star.setTransport(&ssh.Client{}, nil)
|
star.setTransport(&ssh.Client{}, nil)
|
||||||
|
|
||||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||||
return nil, nil, errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
|
return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
|
||||||
}
|
}
|
||||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||||
t.Fatal("session request should not run when agent forwarder init fails")
|
t.Fatal("session request should not run when agent forwarder init fails")
|
||||||
@@ -216,16 +335,16 @@ func TestRequestAgentForwardingReturnsUnavailableError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRequestAgentForwardingWrapsSetupErrorAsUnavailable(t *testing.T) {
|
func TestRequestAgentForwardingWrapsSetupErrorAsUnavailable(t *testing.T) {
|
||||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
})
|
})
|
||||||
|
|
||||||
star := &StarSSH{}
|
star := &StarSSH{}
|
||||||
star.setTransport(&ssh.Client{}, nil)
|
star.setTransport(&ssh.Client{}, nil)
|
||||||
|
|
||||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||||
return nil, nil, errors.New("dial unix /tmp/ssh-broken.sock: connect: permission denied")
|
return errors.New("dial unix /tmp/ssh-broken.sock: connect: permission denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
err := star.RequestAgentForwarding(&ssh.Session{})
|
err := star.RequestAgentForwarding(&ssh.Session{})
|
||||||
@@ -235,11 +354,11 @@ func TestRequestAgentForwardingWrapsSetupErrorAsUnavailable(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) {
|
func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) {
|
||||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||||
})
|
})
|
||||||
@@ -247,10 +366,12 @@ func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) {
|
|||||||
star := &StarSSH{}
|
star := &StarSSH{}
|
||||||
star.setTransport(&ssh.Client{}, nil)
|
star.setTransport(&ssh.Client{}, nil)
|
||||||
|
|
||||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||||
return sshagent.NewKeyring(), &testCloser{}, nil
|
return nil
|
||||||
|
}
|
||||||
|
routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
|
||||||
|
return &testCloser{}, nil
|
||||||
}
|
}
|
||||||
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil }
|
|
||||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||||
return errors.New("forwarding request denied")
|
return errors.New("forwarding request denied")
|
||||||
}
|
}
|
||||||
@@ -263,12 +384,12 @@ func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) {
|
|||||||
|
|
||||||
func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) {
|
func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) {
|
||||||
oldNewSSHSession := newSSHSession
|
oldNewSSHSession := newSSHSession
|
||||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
newSSHSession = oldNewSSHSession
|
newSSHSession = oldNewSSHSession
|
||||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||||
})
|
})
|
||||||
@@ -283,10 +404,12 @@ func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) {
|
|||||||
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||||
return &ssh.Session{}, nil
|
return &ssh.Session{}, nil
|
||||||
}
|
}
|
||||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||||
return sshagent.NewKeyring(), &testCloser{}, nil
|
return nil
|
||||||
|
}
|
||||||
|
routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
|
||||||
|
return &testCloser{}, nil
|
||||||
}
|
}
|
||||||
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil }
|
|
||||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||||
return errors.New("forwarding request denied")
|
return errors.New("forwarding request denied")
|
||||||
}
|
}
|
||||||
@@ -298,10 +421,10 @@ func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) {
|
|||||||
|
|
||||||
func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) {
|
func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) {
|
||||||
oldNewSSHSession := newSSHSession
|
oldNewSSHSession := newSSHSession
|
||||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
newSSHSession = oldNewSSHSession
|
newSSHSession = oldNewSSHSession
|
||||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
})
|
})
|
||||||
|
|
||||||
star := &StarSSH{
|
star := &StarSSH{
|
||||||
@@ -314,8 +437,8 @@ func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) {
|
|||||||
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||||
return &ssh.Session{}, nil
|
return &ssh.Session{}, nil
|
||||||
}
|
}
|
||||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||||
return nil, nil, errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
|
return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := star.NewExecSession(); err != nil {
|
if _, err := star.NewExecSession(); err != nil {
|
||||||
@@ -325,10 +448,10 @@ func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) {
|
|||||||
|
|
||||||
func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) {
|
func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) {
|
||||||
oldNewSSHSession := newSSHSession
|
oldNewSSHSession := newSSHSession
|
||||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
newSSHSession = oldNewSSHSession
|
newSSHSession = oldNewSSHSession
|
||||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
})
|
})
|
||||||
|
|
||||||
star := &StarSSH{
|
star := &StarSSH{
|
||||||
@@ -341,8 +464,8 @@ func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) {
|
|||||||
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||||
return &ssh.Session{}, nil
|
return &ssh.Session{}, nil
|
||||||
}
|
}
|
||||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||||
return nil, nil, errors.New("dial unix /tmp/ssh-broken.sock: connect: connection refused")
|
return errors.New("dial unix /tmp/ssh-broken.sock: connect: connection refused")
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := star.NewExecSession(); err != nil {
|
if _, err := star.NewExecSession(); err != nil {
|
||||||
@@ -351,11 +474,11 @@ func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) {
|
func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) {
|
||||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||||
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||||
oldCloseSSHClient := closeSSHClient
|
oldCloseSSHClient := closeSSHClient
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||||
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||||
closeSSHClient = oldCloseSSHClient
|
closeSSHClient = oldCloseSSHClient
|
||||||
})
|
})
|
||||||
@@ -370,13 +493,13 @@ func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) {
|
|||||||
started := make(chan struct{})
|
started := make(chan struct{})
|
||||||
release := make(chan struct{})
|
release := make(chan struct{})
|
||||||
closer := &testCloser{}
|
closer := &testCloser{}
|
||||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
|
||||||
close(started)
|
close(started)
|
||||||
<-release
|
<-release
|
||||||
return sshagent.NewKeyring(), closer, nil
|
return closer, nil
|
||||||
}
|
|
||||||
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error {
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
closeSSHClient = func(client sshClientRequester) error { return nil }
|
closeSSHClient = func(client sshClientRequester) error { return nil }
|
||||||
|
|
||||||
@@ -415,3 +538,128 @@ func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) {
|
|||||||
t.Fatal("expected no leaked agent forwarder after close race")
|
t.Fatal("expected no leaked agent forwarder after close race")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxySSHAgentChannelClosesBlockedAgentConnWhenRemoteChannelEnds(t *testing.T) {
|
||||||
|
agentConn, peerConn := net.Pipe()
|
||||||
|
defer peerConn.Close()
|
||||||
|
|
||||||
|
tracked := &trackedConn{Conn: agentConn}
|
||||||
|
channel := newTestSSHChannel(func(p []byte) (int, error) {
|
||||||
|
return 0, io.EOF
|
||||||
|
})
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
proxySSHAgentChannel(channel, tracked)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("proxySSHAgentChannel did not exit after remote EOF")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tracked.closed.Load() == 0 {
|
||||||
|
t.Fatal("expected local agent connection to be closed")
|
||||||
|
}
|
||||||
|
if channel.closed.Load() == 0 {
|
||||||
|
t.Fatal("expected ssh channel to be closed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHAgentForwardProxyCloseClosesActiveBridges(t *testing.T) {
|
||||||
|
agentConn, peerConn := net.Pipe()
|
||||||
|
defer peerConn.Close()
|
||||||
|
|
||||||
|
tracked := &trackedConn{Conn: agentConn}
|
||||||
|
channel := newBlockingTestSSHChannel()
|
||||||
|
proxy := &sshAgentForwardProxy{
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
|
active: make(map[*sshAgentForwardBridge]struct{}),
|
||||||
|
}
|
||||||
|
bridge := &sshAgentForwardBridge{
|
||||||
|
proxy: proxy,
|
||||||
|
channel: channel,
|
||||||
|
conn: tracked,
|
||||||
|
}
|
||||||
|
if !proxy.registerBridge(bridge) {
|
||||||
|
t.Fatal("expected bridge registration to succeed")
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
bridge.run()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := proxy.Close(); err != nil {
|
||||||
|
t.Fatalf("close proxy: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("bridge did not exit after proxy close")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tracked.closed.Load() == 0 {
|
||||||
|
t.Fatal("expected proxy close to close local agent connection")
|
||||||
|
}
|
||||||
|
if channel.closed.Load() == 0 {
|
||||||
|
t.Fatal("expected proxy close to close ssh channel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleSSHAgentForwardChannelUsesForwardTimeout(t *testing.T) {
|
||||||
|
oldDialResolvedSSHAgent := dialResolvedSSHAgentFunc
|
||||||
|
t.Cleanup(func() {
|
||||||
|
dialResolvedSSHAgentFunc = oldDialResolvedSSHAgent
|
||||||
|
})
|
||||||
|
|
||||||
|
agentConn, peerConn := net.Pipe()
|
||||||
|
defer peerConn.Close()
|
||||||
|
tracked := &trackedConn{Conn: agentConn}
|
||||||
|
dialResolvedSSHAgentFunc = func(resolved resolvedSSHAgentEndpoint, timeout time.Duration) (net.Conn, error) {
|
||||||
|
return tracked, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
channel := newBlockingTestSSHChannel()
|
||||||
|
newChannel := &testNewChannel{
|
||||||
|
channel: channel,
|
||||||
|
}
|
||||||
|
proxy := &sshAgentForwardProxy{
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
|
active: make(map[*sshAgentForwardBridge]struct{}),
|
||||||
|
}
|
||||||
|
handleSSHAgentForwardChannel(proxy, newChannel, sshAgentTimeouts{
|
||||||
|
Endpoint: "/tmp/agent.sock",
|
||||||
|
Forward: 20 * time.Millisecond,
|
||||||
|
})
|
||||||
|
|
||||||
|
if !newChannel.accepted.Load() {
|
||||||
|
t.Fatal("expected channel to be accepted")
|
||||||
|
}
|
||||||
|
|
||||||
|
waitUntil(t, time.Second, func() bool {
|
||||||
|
return tracked.closed.Load() > 0 && channel.closed.Load() > 0
|
||||||
|
}, "forwarded agent bridge did not close both sides after idle timeout")
|
||||||
|
|
||||||
|
waitUntil(t, time.Second, func() bool {
|
||||||
|
proxy.activeMu.Lock()
|
||||||
|
defer proxy.activeMu.Unlock()
|
||||||
|
return len(proxy.active) == 0
|
||||||
|
}, "forwarded agent bridge did not unregister after idle timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitUntil(t *testing.T, timeout time.Duration, condition func() bool, message string) {
|
||||||
|
t.Helper()
|
||||||
|
deadline := time.Now().Add(timeout)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
if condition() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
}
|
||||||
|
t.Fatal(message)
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,25 +4,14 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
"golang.org/x/crypto/ssh/agent"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrHostKeyCallbackRequired = errors.New("host key callback is required; use DefaultAllowHostKeyCallback to explicitly allow any host key")
|
var ErrHostKeyCallbackRequired = errors.New("host key callback is required; use DefaultAllowHostKeyCallback to explicitly allow any host key")
|
||||||
var errSSHAgentUnavailable = errors.New("ssh-agent unavailable")
|
|
||||||
|
|
||||||
var defaultAuthOrder = []AuthMethodKind{
|
|
||||||
AuthMethodSSHAgent,
|
|
||||||
AuthMethodPrivateKey,
|
|
||||||
AuthMethodPassword,
|
|
||||||
AuthMethodKeyboardInteractive,
|
|
||||||
}
|
|
||||||
|
|
||||||
func DefaultAllowHostKeyCallback(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
func DefaultAllowHostKeyCallback(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||||
return nil
|
return nil
|
||||||
@@ -42,14 +31,39 @@ func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
|
|||||||
return nil, ErrHostKeyCallbackRequired
|
return nil, ErrHostKeyCallbackRequired
|
||||||
}
|
}
|
||||||
|
|
||||||
loginCtx, cancel := contextWithLoginTimeout(ctx, info.Timeout)
|
authTimeout := effectiveLoginTimeout(info)
|
||||||
|
loginCtx, cancel := contextWithLoginTimeout(ctx, authTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
order, err := normalizeAuthOrder(info.AuthOrder)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if shouldRetrySSHAgentAuth(info, order) {
|
||||||
|
agentAttempt := newSSHAgentAuthAttempt()
|
||||||
|
for {
|
||||||
|
agentAttempt.begin()
|
||||||
|
sshInfo, err := loginOnceWithContext(loginCtx, info, authTimeout, agentAttempt)
|
||||||
|
if err == nil {
|
||||||
|
return sshInfo, nil
|
||||||
|
}
|
||||||
|
if errors.Is(err, errRetrySSHAgentAuth) && loginCtx.Err() == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return sshInfo, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return loginOnceWithContext(loginCtx, info, authTimeout, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func loginOnceWithContext(ctx context.Context, info LoginInput, authTimeout time.Duration, agentAttempt *sshAgentAuthAttempt) (*StarSSH, error) {
|
||||||
sshInfo := &StarSSH{
|
sshInfo := &StarSSH{
|
||||||
LoginInfo: info,
|
LoginInfo: info,
|
||||||
}
|
}
|
||||||
|
|
||||||
auth, authCleanup, err := buildAuthMethods(info)
|
auth, authCleanup, err := buildAuthMethodsWithAgentAttempt(info, agentAttempt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -76,7 +90,7 @@ func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
|
|||||||
clientConfig := &ssh.ClientConfig{
|
clientConfig := &ssh.ClientConfig{
|
||||||
User: info.User,
|
User: info.User,
|
||||||
Auth: auth,
|
Auth: auth,
|
||||||
Timeout: info.Timeout,
|
Timeout: authTimeout,
|
||||||
HostKeyCallback: hostKeyCallback,
|
HostKeyCallback: hostKeyCallback,
|
||||||
BannerCallback: bannerCallback,
|
BannerCallback: bannerCallback,
|
||||||
}
|
}
|
||||||
@@ -89,11 +103,11 @@ func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
targetAddr := joinHostPort(info.Addr, info.Port)
|
targetAddr := joinHostPort(info.Addr, info.Port)
|
||||||
rawConn, upstream, err := dialTargetConn(loginCtx, info)
|
rawConn, upstream, err := dialTargetConn(ctx, info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return sshInfo, err
|
return sshInfo, err
|
||||||
}
|
}
|
||||||
restoreDeadline := applyConnDeadline(rawConn, loginCtx, info.Timeout)
|
restoreDeadline := applyConnDeadline(rawConn, ctx, authTimeout)
|
||||||
defer restoreDeadline()
|
defer restoreDeadline()
|
||||||
|
|
||||||
clientConn, chans, reqs, err := ssh.NewClientConn(rawConn, targetAddr, clientConfig)
|
clientConn, chans, reqs, err := ssh.NewClientConn(rawConn, targetAddr, clientConfig)
|
||||||
@@ -130,6 +144,7 @@ func LoginSimple(host string, user string, passwd string, prikeyPath string, por
|
|||||||
Addr: host,
|
Addr: host,
|
||||||
Port: port,
|
Port: port,
|
||||||
Timeout: timeout,
|
Timeout: timeout,
|
||||||
|
DialTimeout: timeout,
|
||||||
User: user,
|
User: user,
|
||||||
HostKeyCallback: DefaultAllowHostKeyCallback,
|
HostKeyCallback: DefaultAllowHostKeyCallback,
|
||||||
}
|
}
|
||||||
@@ -154,209 +169,25 @@ func normalizeLoginInput(info LoginInput) LoginInput {
|
|||||||
if info.Port <= 0 {
|
if info.Port <= 0 {
|
||||||
info.Port = defaultSSHPort
|
info.Port = defaultSSHPort
|
||||||
}
|
}
|
||||||
if info.Timeout <= 0 {
|
|
||||||
info.Timeout = defaultLoginTimeout
|
|
||||||
}
|
|
||||||
return info
|
return info
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) {
|
func effectiveLoginTimeout(info LoginInput) time.Duration {
|
||||||
order, err := normalizeAuthOrder(info.AuthOrder)
|
if info.Timeout <= 0 {
|
||||||
if err != nil {
|
return 0
|
||||||
return nil, nil, err
|
|
||||||
}
|
}
|
||||||
|
return info.Timeout
|
||||||
auth := make([]ssh.AuthMethod, 0, len(order))
|
|
||||||
var agentErr error
|
|
||||||
var cleanupFuncs []func()
|
|
||||||
|
|
||||||
for _, methodKind := range order {
|
|
||||||
switch methodKind {
|
|
||||||
case AuthMethodPrivateKey:
|
|
||||||
method, err := buildPrivateKeyAuthMethod(info)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
if method != nil {
|
|
||||||
auth = append(auth, method)
|
|
||||||
}
|
|
||||||
case AuthMethodPassword:
|
|
||||||
method := buildPasswordAuthMethod(info.Password, info.PasswordCallback)
|
|
||||||
if method != nil {
|
|
||||||
auth = append(auth, method)
|
|
||||||
}
|
|
||||||
case AuthMethodKeyboardInteractive:
|
|
||||||
method := buildKeyboardInteractiveAuthMethod(info.Password, info.PasswordCallback, info.KeyboardInteractiveCallback)
|
|
||||||
if method != nil {
|
|
||||||
auth = append(auth, method)
|
|
||||||
}
|
|
||||||
case AuthMethodSSHAgent:
|
|
||||||
if info.DisableSSHAgent {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
agentMethod, cleanup, err := buildSSHAgentAuthMethod(info.Timeout)
|
|
||||||
if err != nil {
|
|
||||||
agentErr = err
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if agentMethod != nil {
|
|
||||||
auth = append(auth, agentMethod)
|
|
||||||
}
|
|
||||||
if cleanup != nil {
|
|
||||||
cleanupFuncs = append(cleanupFuncs, cleanup)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(auth) == 0 {
|
|
||||||
if agentErr != nil {
|
|
||||||
return nil, nil, fmt.Errorf("no authentication method provided; ssh-agent unavailable: %w", agentErr)
|
|
||||||
}
|
|
||||||
return nil, nil, errors.New("no authentication method provided: password, private key, or ssh-agent is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
return auth, composeCleanup(cleanupFuncs...), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func normalizeAuthOrder(order []AuthMethodKind) ([]AuthMethodKind, error) {
|
func effectiveDialTimeout(info LoginInput) time.Duration {
|
||||||
if len(order) == 0 {
|
switch {
|
||||||
return append([]AuthMethodKind(nil), defaultAuthOrder...), nil
|
case info.DialTimeout < 0:
|
||||||
}
|
return 0
|
||||||
|
case info.DialTimeout > 0:
|
||||||
normalized := make([]AuthMethodKind, 0, len(order))
|
return info.DialTimeout
|
||||||
seen := make(map[AuthMethodKind]struct{}, len(order))
|
case info.Timeout > 0:
|
||||||
for _, raw := range order {
|
return info.Timeout
|
||||||
kind := AuthMethodKind(strings.ToLower(strings.TrimSpace(string(raw))))
|
|
||||||
if kind == "" {
|
|
||||||
return nil, errors.New("auth order contains an empty auth method")
|
|
||||||
}
|
|
||||||
if !isSupportedAuthMethodKind(kind) {
|
|
||||||
return nil, fmt.Errorf("unsupported auth method %q", raw)
|
|
||||||
}
|
|
||||||
if _, exists := seen[kind]; exists {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
seen[kind] = struct{}{}
|
|
||||||
normalized = append(normalized, kind)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(normalized) == 0 {
|
|
||||||
return nil, errors.New("auth order is empty")
|
|
||||||
}
|
|
||||||
return normalized, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isSupportedAuthMethodKind(kind AuthMethodKind) bool {
|
|
||||||
switch kind {
|
|
||||||
case AuthMethodPrivateKey, AuthMethodPassword, AuthMethodKeyboardInteractive, AuthMethodSSHAgent:
|
|
||||||
return true
|
|
||||||
default:
|
default:
|
||||||
return false
|
return defaultLoginTimeout
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildPrivateKeyAuthMethod(info LoginInput) (ssh.AuthMethod, error) {
|
|
||||||
if strings.TrimSpace(info.Prikey) == "" {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
pemBytes := []byte(info.Prikey)
|
|
||||||
if info.PrikeyPwd == "" {
|
|
||||||
signer, err := ssh.ParsePrivateKey(pemBytes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return ssh.PublicKeys(signer), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
signer, err := ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(info.PrikeyPwd))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return ssh.PublicKeys(signer), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildPasswordAuthMethod(password string, callback func() (string, error)) ssh.AuthMethod {
|
|
||||||
if password != "" {
|
|
||||||
return ssh.Password(password)
|
|
||||||
}
|
|
||||||
if callback != nil {
|
|
||||||
return ssh.PasswordCallback(callback)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildKeyboardInteractiveAuthMethod(
|
|
||||||
password string,
|
|
||||||
passwordCallback func() (string, error),
|
|
||||||
challenge ssh.KeyboardInteractiveChallenge,
|
|
||||||
) ssh.AuthMethod {
|
|
||||||
if challenge != nil {
|
|
||||||
return ssh.KeyboardInteractive(challenge)
|
|
||||||
}
|
|
||||||
if password == "" && passwordCallback == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
keyboardInteractiveChallenge := func(user, instruction string, questions []string, echos []bool) ([]string, error) {
|
|
||||||
if len(questions) == 0 {
|
|
||||||
return []string{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
answer := password
|
|
||||||
if answer == "" {
|
|
||||||
var err error
|
|
||||||
answer, err = passwordCallback()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
answers := make([]string, len(questions))
|
|
||||||
for i := range questions {
|
|
||||||
answers[i] = answer
|
|
||||||
}
|
|
||||||
return answers, nil
|
|
||||||
}
|
|
||||||
return ssh.KeyboardInteractive(keyboardInteractiveChallenge)
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildSSHAgentAuthMethod(timeout time.Duration) (ssh.AuthMethod, func(), error) {
|
|
||||||
conn, err := dialSSHAgent(timeout)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, errSSHAgentUnavailable) {
|
|
||||||
return nil, nil, nil
|
|
||||||
}
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
if conn == nil {
|
|
||||||
return nil, nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
signers, err := agent.NewClient(conn).Signers()
|
|
||||||
if err != nil {
|
|
||||||
_ = conn.Close()
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
if len(signers) == 0 {
|
|
||||||
_ = conn.Close()
|
|
||||||
return nil, nil, errors.New("ssh-agent has no loaded keys")
|
|
||||||
}
|
|
||||||
|
|
||||||
return ssh.PublicKeys(signers...), func() {
|
|
||||||
_ = conn.Close()
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func composeCleanup(funcs ...func()) func() {
|
|
||||||
if len(funcs) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return func() {
|
|
||||||
for i := len(funcs) - 1; i >= 0; i-- {
|
|
||||||
if funcs[i] != nil {
|
|
||||||
funcs[i]()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,763 @@
|
|||||||
|
package starssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
sshagent "golang.org/x/crypto/ssh/agent"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNormalizeLoginInputKeepsZeroAuthTimeout(t *testing.T) {
|
||||||
|
info := normalizeLoginInput(LoginInput{})
|
||||||
|
if info.Port != defaultSSHPort {
|
||||||
|
t.Fatalf("Port=%d want %d", info.Port, defaultSSHPort)
|
||||||
|
}
|
||||||
|
if info.Timeout != 0 {
|
||||||
|
t.Fatalf("Timeout=%v want 0", info.Timeout)
|
||||||
|
}
|
||||||
|
if info.DialTimeout != 0 {
|
||||||
|
t.Fatalf("DialTimeout=%v want 0", info.DialTimeout)
|
||||||
|
}
|
||||||
|
if info.SSHAgentTimeout != 0 {
|
||||||
|
t.Fatalf("SSHAgentTimeout=%v want 0", info.SSHAgentTimeout)
|
||||||
|
}
|
||||||
|
if info.SSHAgentForwardTimeout != 0 {
|
||||||
|
t.Fatalf("SSHAgentForwardTimeout=%v want 0", info.SSHAgentForwardTimeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEffectiveLoginTimeout(t *testing.T) {
|
||||||
|
if got := effectiveLoginTimeout(LoginInput{}); got != 0 {
|
||||||
|
t.Fatalf("zero login timeout should stay zero, got %v", got)
|
||||||
|
}
|
||||||
|
if got := effectiveLoginTimeout(LoginInput{Timeout: 7 * time.Second}); got != 7*time.Second {
|
||||||
|
t.Fatalf("expected explicit login timeout, got %v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEffectiveDialTimeout(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
info LoginInput
|
||||||
|
want time.Duration
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "default fallback",
|
||||||
|
info: LoginInput{},
|
||||||
|
want: defaultLoginTimeout,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "reuse timeout when dial timeout omitted",
|
||||||
|
info: LoginInput{Timeout: 9 * time.Second},
|
||||||
|
want: 9 * time.Second,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit dial timeout wins",
|
||||||
|
info: LoginInput{Timeout: 9 * time.Second, DialTimeout: 3 * time.Second},
|
||||||
|
want: 3 * time.Second,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative dial timeout disables default dial deadline",
|
||||||
|
info: LoginInput{Timeout: 9 * time.Second, DialTimeout: -1},
|
||||||
|
want: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
if got := effectiveDialTimeout(tc.info); got != tc.want {
|
||||||
|
t.Fatalf("effectiveDialTimeout(%+v)=%v want %v", tc.info, got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEffectiveSSHAgentTimeout(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
info LoginInput
|
||||||
|
want time.Duration
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "default fallback without auth timeout",
|
||||||
|
info: LoginInput{},
|
||||||
|
want: defaultSSHAgentTimeout,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "auth timeout does not cap default",
|
||||||
|
info: LoginInput{Timeout: 9 * time.Second},
|
||||||
|
want: defaultSSHAgentTimeout,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit agent timeout wins",
|
||||||
|
info: LoginInput{Timeout: 9 * time.Second, DialTimeout: 3 * time.Second, SSHAgentTimeout: 90 * time.Second},
|
||||||
|
want: 90 * time.Second,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative agent timeout disables operation deadline",
|
||||||
|
info: LoginInput{SSHAgentTimeout: -1},
|
||||||
|
want: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
if got := effectiveSSHAgentTimeout(tc.info); got != tc.want {
|
||||||
|
t.Fatalf("effectiveSSHAgentTimeout(%+v)=%v want %v", tc.info, got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEffectiveSSHAgentForwardTimeout(t *testing.T) {
|
||||||
|
if got := effectiveSSHAgentForwardTimeout(LoginInput{}); got != 0 {
|
||||||
|
t.Fatalf("zero forward timeout should stay zero, got %v", got)
|
||||||
|
}
|
||||||
|
if got := effectiveSSHAgentForwardTimeout(LoginInput{SSHAgentForwardTimeout: 4 * time.Second}); got != 4*time.Second {
|
||||||
|
t.Fatalf("expected explicit forward timeout, got %v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildAuthMethodsUsesSeparateSSHAgentTimeouts(t *testing.T) {
|
||||||
|
oldBuilder := buildSSHAgentAuthMethodFunc
|
||||||
|
t.Cleanup(func() {
|
||||||
|
buildSSHAgentAuthMethodFunc = oldBuilder
|
||||||
|
})
|
||||||
|
|
||||||
|
captured := sshAgentTimeouts{Dial: -2, Operation: -2, Forward: -2}
|
||||||
|
buildSSHAgentAuthMethodFunc = func(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
|
||||||
|
captured = timeouts
|
||||||
|
return ssh.Password("agent"), nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
info := LoginInput{
|
||||||
|
Timeout: 0,
|
||||||
|
DialTimeout: 11 * time.Second,
|
||||||
|
SSHAgentTimeout: 90 * time.Second,
|
||||||
|
SSHAgentForwardTimeout: 4 * time.Second,
|
||||||
|
IdentityAgent: "/tmp/custom-agent.sock",
|
||||||
|
AuthOrder: []AuthMethodKind{AuthMethodSSHAgent},
|
||||||
|
}
|
||||||
|
auth, cleanup, err := buildAuthMethods(info)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildAuthMethods: %v", err)
|
||||||
|
}
|
||||||
|
if cleanup != nil {
|
||||||
|
cleanup()
|
||||||
|
}
|
||||||
|
if len(auth) != 1 {
|
||||||
|
t.Fatalf("expected one auth method, got %d", len(auth))
|
||||||
|
}
|
||||||
|
if captured.Dial != 11*time.Second {
|
||||||
|
t.Fatalf("agent auth builder dial timeout=%v want %v", captured.Dial, 11*time.Second)
|
||||||
|
}
|
||||||
|
if captured.Operation != 90*time.Second {
|
||||||
|
t.Fatalf("agent auth builder operation timeout=%v want %v", captured.Operation, 90*time.Second)
|
||||||
|
}
|
||||||
|
if captured.Forward != 4*time.Second {
|
||||||
|
t.Fatalf("agent auth builder forward timeout=%v want %v", captured.Forward, 4*time.Second)
|
||||||
|
}
|
||||||
|
if captured.Endpoint != "/tmp/custom-agent.sock" {
|
||||||
|
t.Fatalf("agent auth builder endpoint=%q want custom endpoint", captured.Endpoint)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildAuthMethodsUsesSingleAgentAuthMethod(t *testing.T) {
|
||||||
|
oldBuilder := buildSSHAgentAuthMethodFunc
|
||||||
|
t.Cleanup(func() {
|
||||||
|
buildSSHAgentAuthMethodFunc = oldBuilder
|
||||||
|
})
|
||||||
|
|
||||||
|
buildSSHAgentAuthMethodFunc = func(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
|
||||||
|
return ssh.Password("agent"), nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
auth, cleanup, err := buildAuthMethods(LoginInput{
|
||||||
|
AuthOrder: []AuthMethodKind{AuthMethodSSHAgent},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildAuthMethods: %v", err)
|
||||||
|
}
|
||||||
|
if cleanup != nil {
|
||||||
|
cleanup()
|
||||||
|
}
|
||||||
|
if len(auth) != 1 {
|
||||||
|
t.Fatalf("auth methods=%d, want 1", len(auth))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldRetrySSHAgentAuthWhenAgentIsNotFirst(t *testing.T) {
|
||||||
|
order := []AuthMethodKind{AuthMethodPassword, AuthMethodSSHAgent}
|
||||||
|
if !shouldRetrySSHAgentAuth(LoginInput{}, order) {
|
||||||
|
t.Fatal("expected ssh-agent retry when ssh-agent is present after password")
|
||||||
|
}
|
||||||
|
if shouldRetrySSHAgentAuth(LoginInput{DisableSSHAgent: true}, order) {
|
||||||
|
t.Fatal("expected ssh-agent retry disabled when DisableSSHAgent is true")
|
||||||
|
}
|
||||||
|
if shouldRetrySSHAgentAuth(LoginInput{}, []AuthMethodKind{AuthMethodPassword}) {
|
||||||
|
t.Fatal("expected no ssh-agent retry when ssh-agent auth is absent")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildAuthMethodsWithAgentAttemptMarksNonFirstAgentForRetry(t *testing.T) {
|
||||||
|
oldBuilder := buildSSHAgentAuthMethodFunc
|
||||||
|
t.Cleanup(func() {
|
||||||
|
buildSSHAgentAuthMethodFunc = oldBuilder
|
||||||
|
})
|
||||||
|
|
||||||
|
buildSSHAgentAuthMethodFunc = func(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
|
||||||
|
if timeouts.SignFailure == nil {
|
||||||
|
t.Fatal("expected SignFailure callback for non-first ssh-agent auth")
|
||||||
|
}
|
||||||
|
if timeouts.SkipFingerprints != nil {
|
||||||
|
t.Fatalf("unexpected initial skip fingerprints: %#v", timeouts.SkipFingerprints)
|
||||||
|
}
|
||||||
|
return ssh.Password("agent"), nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
auth, cleanup, err := buildAuthMethodsWithAgentAttempt(LoginInput{
|
||||||
|
Password: "secret",
|
||||||
|
AuthOrder: []AuthMethodKind{AuthMethodPassword, AuthMethodSSHAgent},
|
||||||
|
}, newSSHAgentAuthAttempt())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildAuthMethodsWithAgentAttempt: %v", err)
|
||||||
|
}
|
||||||
|
if cleanup != nil {
|
||||||
|
cleanup()
|
||||||
|
}
|
||||||
|
if len(auth) != 2 {
|
||||||
|
t.Fatalf("auth methods=%d want 2", len(auth))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAgentRetryPendingBlocksFallbackAuthThenResets(t *testing.T) {
|
||||||
|
attempt := newSSHAgentAuthAttempt()
|
||||||
|
attempt.skipFingerprint("SHA256:test")
|
||||||
|
if err := checkSSHAgentRetryPending(attempt); !errors.Is(err, errRetrySSHAgentAuth) {
|
||||||
|
t.Fatalf("retry pending err=%v want errRetrySSHAgentAuth", err)
|
||||||
|
}
|
||||||
|
attempt.begin()
|
||||||
|
if err := checkSSHAgentRetryPending(attempt); err != nil {
|
||||||
|
t.Fatalf("retry should reset on next attempt: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAgentRetryPendingBlocksPrivateKeyAuth(t *testing.T) {
|
||||||
|
signer := mustGenerateTestSigner(t)
|
||||||
|
attempt := newSSHAgentAuthAttempt()
|
||||||
|
callback := privateKeySignersCallback(signer, attempt)
|
||||||
|
|
||||||
|
signers, err := callback()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("private key callback before retry: %v", err)
|
||||||
|
}
|
||||||
|
if len(signers) != 1 || signers[0] != signer {
|
||||||
|
t.Fatalf("private key callback returned %#v, want original signer", signers)
|
||||||
|
}
|
||||||
|
|
||||||
|
attempt.skipFingerprint("SHA256:test")
|
||||||
|
signers, err = callback()
|
||||||
|
if !errors.Is(err, errRetrySSHAgentAuth) {
|
||||||
|
t.Fatalf("private key callback err=%v want errRetrySSHAgentAuth", err)
|
||||||
|
}
|
||||||
|
if signers != nil {
|
||||||
|
t.Fatalf("private key callback signers=%#v want nil while retry pending", signers)
|
||||||
|
}
|
||||||
|
|
||||||
|
attempt.begin()
|
||||||
|
signers, err = callback()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("private key callback after retry reset: %v", err)
|
||||||
|
}
|
||||||
|
if len(signers) != 1 || signers[0] != signer {
|
||||||
|
t.Fatalf("private key callback after retry returned %#v, want original signer", signers)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterSSHAgentSignersSkipsSignerAfterSignFailure(t *testing.T) {
|
||||||
|
firstSigner := mustGenerateTestSigner(t)
|
||||||
|
secondSigner := mustGenerateTestSigner(t)
|
||||||
|
failingFirstSigner := &testFailingSigner{Signer: firstSigner, err: errors.New("first agent key cannot sign")}
|
||||||
|
|
||||||
|
attempt := newSSHAgentAuthAttempt()
|
||||||
|
firstMethods := filterSSHAgentSignersForRetry([]ssh.Signer{failingFirstSigner, secondSigner}, sshAgentTimeouts{
|
||||||
|
SignFailure: attempt.recordSignFailure,
|
||||||
|
SkipFingerprints: attempt.skipSnapshot(),
|
||||||
|
})
|
||||||
|
if len(firstMethods) != 2 {
|
||||||
|
t.Fatalf("first auth method signers=%d want 2", len(firstMethods))
|
||||||
|
}
|
||||||
|
if _, err := firstMethods[0].Sign(nil, []byte("challenge")); !errors.Is(err, errRetrySSHAgentAuth) {
|
||||||
|
t.Fatalf("first signer err=%v want errRetrySSHAgentAuth", err)
|
||||||
|
}
|
||||||
|
secondMethods := filterSSHAgentSignersForRetry([]ssh.Signer{failingFirstSigner, secondSigner}, sshAgentTimeouts{
|
||||||
|
SignFailure: attempt.recordSignFailure,
|
||||||
|
SkipFingerprints: attempt.skipSnapshot(),
|
||||||
|
})
|
||||||
|
if len(secondMethods) != 1 {
|
||||||
|
t.Fatalf("second auth method signers=%d want 1", len(secondMethods))
|
||||||
|
}
|
||||||
|
if string(secondMethods[0].PublicKey().Marshal()) != string(secondSigner.PublicKey().Marshal()) {
|
||||||
|
t.Fatalf("second auth method did not skip failed first key")
|
||||||
|
}
|
||||||
|
signature, err := secondMethods[0].Sign(nil, []byte("challenge"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second signer Sign: %v", err)
|
||||||
|
}
|
||||||
|
if signature == nil {
|
||||||
|
t.Fatal("second signer returned nil signature")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildAuthMethodsSkipsFailedAgentSignerOnRetry(t *testing.T) {
|
||||||
|
firstSigner := mustGenerateTestSigner(t)
|
||||||
|
secondSigner := mustGenerateTestSigner(t)
|
||||||
|
wantErr := errors.New("first agent key cannot sign")
|
||||||
|
failingFirstSigner := &testFailingSigner{Signer: firstSigner, err: wantErr}
|
||||||
|
|
||||||
|
oldBuilder := buildSSHAgentAuthMethodFunc
|
||||||
|
t.Cleanup(func() {
|
||||||
|
buildSSHAgentAuthMethodFunc = oldBuilder
|
||||||
|
})
|
||||||
|
|
||||||
|
var buildCalls int
|
||||||
|
buildSSHAgentAuthMethodFunc = func(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
|
||||||
|
buildCalls++
|
||||||
|
filteredSigners := filterSSHAgentSignersForRetry([]ssh.Signer{failingFirstSigner, secondSigner}, timeouts)
|
||||||
|
if buildCalls == 1 {
|
||||||
|
if len(filteredSigners) != 2 {
|
||||||
|
t.Fatalf("first build signers=%d want 2", len(filteredSigners))
|
||||||
|
}
|
||||||
|
return ssh.PublicKeys(filteredSigners...), nil, nil
|
||||||
|
}
|
||||||
|
if len(filteredSigners) != 1 {
|
||||||
|
t.Fatalf("retry build signers=%d want 1", len(filteredSigners))
|
||||||
|
}
|
||||||
|
if string(filteredSigners[0].PublicKey().Marshal()) != string(secondSigner.PublicKey().Marshal()) {
|
||||||
|
t.Fatal("retry build did not skip failed signer")
|
||||||
|
}
|
||||||
|
return ssh.PublicKeys(filteredSigners...), nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
attempt := newSSHAgentAuthAttempt()
|
||||||
|
attempt.begin()
|
||||||
|
auth, cleanup, err := buildAuthMethodsWithAgentAttempt(LoginInput{
|
||||||
|
AuthOrder: []AuthMethodKind{AuthMethodSSHAgent},
|
||||||
|
}, attempt)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first buildAuthMethodsWithAgentAttempt: %v", err)
|
||||||
|
}
|
||||||
|
if cleanup != nil {
|
||||||
|
cleanup()
|
||||||
|
}
|
||||||
|
if len(auth) != 1 {
|
||||||
|
t.Fatalf("first auth methods=%d want 1", len(auth))
|
||||||
|
}
|
||||||
|
if _, err := failingFirstSigner.Sign(rand.Reader, []byte("challenge")); !errors.Is(err, wantErr) {
|
||||||
|
t.Fatalf("raw failing signer err=%v", err)
|
||||||
|
}
|
||||||
|
firstWrapped := filterSSHAgentSignersForRetry([]ssh.Signer{failingFirstSigner}, sshAgentTimeouts{
|
||||||
|
SignFailure: attempt.recordSignFailure,
|
||||||
|
})[0]
|
||||||
|
if _, err := firstWrapped.Sign(rand.Reader, []byte("challenge")); !errors.Is(err, errRetrySSHAgentAuth) {
|
||||||
|
t.Fatalf("wrapped failing signer err=%v want errRetrySSHAgentAuth", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
attempt.begin()
|
||||||
|
auth, cleanup, err = buildAuthMethodsWithAgentAttempt(LoginInput{
|
||||||
|
AuthOrder: []AuthMethodKind{AuthMethodSSHAgent},
|
||||||
|
}, attempt)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("retry buildAuthMethodsWithAgentAttempt: %v", err)
|
||||||
|
}
|
||||||
|
if cleanup != nil {
|
||||||
|
cleanup()
|
||||||
|
}
|
||||||
|
if len(auth) != 1 {
|
||||||
|
t.Fatalf("retry auth methods=%d want 1", len(auth))
|
||||||
|
}
|
||||||
|
if buildCalls != 2 {
|
||||||
|
t.Fatalf("build calls=%d want 2", buildCalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOrderSSHAgentSignersPrefersPriorityComment(t *testing.T) {
|
||||||
|
plainSigner := mustGenerateTestSigner(t)
|
||||||
|
prioritySigner := mustGenerateCommentedTestSigner(t, "priority=40")
|
||||||
|
|
||||||
|
ordered := orderSSHAgentSigners([]ssh.Signer{plainSigner, prioritySigner})
|
||||||
|
if len(ordered) != 2 {
|
||||||
|
t.Fatalf("ordered signers=%d want 2", len(ordered))
|
||||||
|
}
|
||||||
|
if string(ordered[0].PublicKey().Marshal()) != string(prioritySigner.PublicKey().Marshal()) {
|
||||||
|
t.Fatalf("priority signer should be first, got %s", sshAgentSignerComment(ordered[0]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOrderSSHAgentSignersPrefersCardKeys(t *testing.T) {
|
||||||
|
plainSigner := mustGenerateTestSigner(t)
|
||||||
|
cardSigner := mustGenerateCommentedTestSigner(t, "cardno:26_865_673")
|
||||||
|
|
||||||
|
ordered := orderSSHAgentSigners([]ssh.Signer{plainSigner, cardSigner})
|
||||||
|
if len(ordered) != 2 {
|
||||||
|
t.Fatalf("ordered signers=%d want 2", len(ordered))
|
||||||
|
}
|
||||||
|
if string(ordered[0].PublicKey().Marshal()) != string(cardSigner.PublicKey().Marshal()) {
|
||||||
|
t.Fatalf("card signer should be first, got %s", sshAgentSignerComment(ordered[0]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOrderSSHAgentSignersKeepsStableOrderWithoutHints(t *testing.T) {
|
||||||
|
firstSigner := mustGenerateTestSigner(t)
|
||||||
|
secondSigner := mustGenerateTestSigner(t)
|
||||||
|
|
||||||
|
ordered := orderSSHAgentSigners([]ssh.Signer{firstSigner, secondSigner})
|
||||||
|
if len(ordered) != 2 {
|
||||||
|
t.Fatalf("ordered signers=%d want 2", len(ordered))
|
||||||
|
}
|
||||||
|
if string(ordered[0].PublicKey().Marshal()) != string(firstSigner.PublicKey().Marshal()) {
|
||||||
|
t.Fatalf("first signer changed order without hints")
|
||||||
|
}
|
||||||
|
if string(ordered[1].PublicKey().Marshal()) != string(secondSigner.PublicKey().Marshal()) {
|
||||||
|
t.Fatalf("second signer changed order without hints")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHAgentSignerEmitsSignDebugWithoutChangingError(t *testing.T) {
|
||||||
|
signer := mustGenerateTestSigner(t)
|
||||||
|
wantErr := errors.New("agent refused operation")
|
||||||
|
var debugCalls int
|
||||||
|
wrapped := wrapSSHAgentSigner(&testFailingSigner{Signer: signer, err: wantErr}, sshAgentSignerOptions{
|
||||||
|
Resolved: resolvedSSHAgentEndpoint{
|
||||||
|
Endpoint: "/tmp/debug-agent.sock",
|
||||||
|
Source: "identity-agent",
|
||||||
|
Network: "unix",
|
||||||
|
},
|
||||||
|
Debug: func(event SSHAgentDebugEvent) {
|
||||||
|
debugCalls++
|
||||||
|
if event.Step != "auth" || event.Phase != "sign" {
|
||||||
|
t.Fatalf("unexpected debug event: %+v", event)
|
||||||
|
}
|
||||||
|
if event.Endpoint != "/tmp/debug-agent.sock" || event.Source != "identity-agent" || event.Network != "unix" {
|
||||||
|
t.Fatalf("unexpected endpoint details: %+v", event)
|
||||||
|
}
|
||||||
|
if event.Status != "error" || !errors.Is(event.Err, wantErr) {
|
||||||
|
t.Fatalf("unexpected sign status: %+v", event)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := wrapped.Sign(rand.Reader, []byte("challenge"))
|
||||||
|
if !errors.Is(err, wantErr) {
|
||||||
|
t.Fatalf("Sign err=%v want original signer error", err)
|
||||||
|
}
|
||||||
|
if debugCalls != 1 {
|
||||||
|
t.Fatalf("debug calls=%d want 1", debugCalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHAgentRetrySignerPrefersRSASHA2(t *testing.T) {
|
||||||
|
signer := mustGenerateRSATestSigner(t)
|
||||||
|
spy := &testAlgorithmSpySigner{Signer: signer}
|
||||||
|
wrapped, ok := wrapSSHAgentSignerForRetry(spy, func(ssh.PublicKey, error) {}).(ssh.AlgorithmSigner)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("wrapped signer does not implement AlgorithmSigner")
|
||||||
|
}
|
||||||
|
|
||||||
|
signature, err := wrapped.SignWithAlgorithm(rand.Reader, []byte("challenge"), ssh.KeyAlgoRSA)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SignWithAlgorithm: %v", err)
|
||||||
|
}
|
||||||
|
if spy.lastAlgorithm != ssh.KeyAlgoRSASHA256 {
|
||||||
|
t.Fatalf("last algorithm=%q want %q", spy.lastAlgorithm, ssh.KeyAlgoRSASHA256)
|
||||||
|
}
|
||||||
|
if signature.Format != ssh.KeyAlgoRSASHA256 {
|
||||||
|
t.Fatalf("signature format=%q want %q", signature.Format, ssh.KeyAlgoRSASHA256)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHAgentRetrySignerKeepsRestrictedRSA(t *testing.T) {
|
||||||
|
signer := mustGenerateRSATestSigner(t)
|
||||||
|
restricted, err := ssh.NewSignerWithAlgorithms(signer.(ssh.AlgorithmSigner), []string{ssh.KeyAlgoRSA})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewSignerWithAlgorithms: %v", err)
|
||||||
|
}
|
||||||
|
spy := &testMultiAlgorithmSpySigner{
|
||||||
|
testAlgorithmSpySigner: &testAlgorithmSpySigner{Signer: restricted},
|
||||||
|
}
|
||||||
|
wrapped, ok := wrapSSHAgentSignerForRetry(spy, func(ssh.PublicKey, error) {}).(ssh.AlgorithmSigner)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("wrapped signer does not implement AlgorithmSigner")
|
||||||
|
}
|
||||||
|
|
||||||
|
signature, err := wrapped.SignWithAlgorithm(rand.Reader, []byte("challenge"), ssh.KeyAlgoRSA)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SignWithAlgorithm: %v", err)
|
||||||
|
}
|
||||||
|
if spy.lastAlgorithm != ssh.KeyAlgoRSA {
|
||||||
|
t.Fatalf("last algorithm=%q want %q", spy.lastAlgorithm, ssh.KeyAlgoRSA)
|
||||||
|
}
|
||||||
|
if signature.Format != ssh.KeyAlgoRSA {
|
||||||
|
t.Fatalf("signature format=%q want %q", signature.Format, ssh.KeyAlgoRSA)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type deadlineSpyConn struct {
|
||||||
|
net.Conn
|
||||||
|
mu sync.Mutex
|
||||||
|
deadlines []time.Time
|
||||||
|
readErr error
|
||||||
|
writeErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
type testFailingSigner struct {
|
||||||
|
ssh.Signer
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testFailingSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) {
|
||||||
|
return nil, s.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testFailingSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) {
|
||||||
|
return nil, s.err
|
||||||
|
}
|
||||||
|
|
||||||
|
type testAlgorithmSpySigner struct {
|
||||||
|
ssh.Signer
|
||||||
|
lastAlgorithm string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testAlgorithmSpySigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) {
|
||||||
|
s.lastAlgorithm = algorithm
|
||||||
|
return s.Signer.(ssh.AlgorithmSigner).SignWithAlgorithm(rand, data, algorithm)
|
||||||
|
}
|
||||||
|
|
||||||
|
type testMultiAlgorithmSpySigner struct {
|
||||||
|
*testAlgorithmSpySigner
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testMultiAlgorithmSpySigner) Algorithms() []string {
|
||||||
|
if multiAlgorithmSigner, ok := s.Signer.(ssh.MultiAlgorithmSigner); ok {
|
||||||
|
return multiAlgorithmSigner.Algorithms()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustGenerateTestSigner(t *testing.T) ssh.Signer {
|
||||||
|
t.Helper()
|
||||||
|
_, key, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("generate test private key: %v", err)
|
||||||
|
}
|
||||||
|
signer, err := ssh.NewSignerFromKey(key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("new test signer: %v", err)
|
||||||
|
}
|
||||||
|
return signer
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustGenerateCommentedTestSigner(t *testing.T, comment string) ssh.Signer {
|
||||||
|
t.Helper()
|
||||||
|
baseSigner := mustGenerateTestSigner(t)
|
||||||
|
publicKey := baseSigner.PublicKey()
|
||||||
|
return &commentedTestSigner{
|
||||||
|
Signer: baseSigner,
|
||||||
|
publicKey: &sshagent.Key{
|
||||||
|
Format: publicKey.Type(),
|
||||||
|
Blob: publicKey.Marshal(),
|
||||||
|
Comment: comment,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type commentedTestSigner struct {
|
||||||
|
ssh.Signer
|
||||||
|
publicKey ssh.PublicKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *commentedTestSigner) PublicKey() ssh.PublicKey {
|
||||||
|
return s.publicKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustGenerateRSATestSigner(t *testing.T) ssh.Signer {
|
||||||
|
t.Helper()
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("generate rsa test private key: %v", err)
|
||||||
|
}
|
||||||
|
signer, err := ssh.NewSignerFromKey(key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("new rsa test signer: %v", err)
|
||||||
|
}
|
||||||
|
return signer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *deadlineSpyConn) SetDeadline(deadline time.Time) error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.deadlines = append(c.deadlines, deadline)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *deadlineSpyConn) deadlineCount() int {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
return len(c.deadlines)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *deadlineSpyConn) firstDeadline() time.Time {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
return c.deadlines[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *deadlineSpyConn) Read(p []byte) (int, error) {
|
||||||
|
if c.readErr != nil {
|
||||||
|
return 0, c.readErr
|
||||||
|
}
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *deadlineSpyConn) Write(p []byte) (int, error) {
|
||||||
|
if c.writeErr != nil {
|
||||||
|
return 0, c.writeErr
|
||||||
|
}
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapSSHAgentConnWithDeadlineSetsReadDeadline(t *testing.T) {
|
||||||
|
spy := &deadlineSpyConn{readErr: io.EOF}
|
||||||
|
conn := wrapSSHAgentConnWithDeadline(spy, 2*time.Second)
|
||||||
|
buf := make([]byte, 1)
|
||||||
|
if _, err := conn.Read(buf); !errors.Is(err, io.EOF) {
|
||||||
|
t.Fatalf("Read err=%v", err)
|
||||||
|
}
|
||||||
|
if spy.deadlineCount() != 1 {
|
||||||
|
t.Fatalf("deadlines=%d want 1", spy.deadlineCount())
|
||||||
|
}
|
||||||
|
if firstDeadline := spy.firstDeadline(); time.Until(firstDeadline) <= 0 {
|
||||||
|
t.Fatalf("deadline=%v should be in the future", firstDeadline)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapSSHAgentConnWithDeadlineSetsWriteDeadline(t *testing.T) {
|
||||||
|
spy := &deadlineSpyConn{}
|
||||||
|
conn := wrapSSHAgentConnWithDeadline(spy, 2*time.Second)
|
||||||
|
if _, err := conn.Write([]byte("x")); err != nil {
|
||||||
|
t.Fatalf("Write err=%v", err)
|
||||||
|
}
|
||||||
|
if spy.deadlineCount() != 1 {
|
||||||
|
t.Fatalf("deadlines=%d want 1", spy.deadlineCount())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveSSHAgentEndpointUsesIdentityAgent(t *testing.T) {
|
||||||
|
t.Setenv("SSH_AUTH_SOCK", "/tmp/env-agent.sock")
|
||||||
|
resolved, err := resolveSSHAgentEndpoint(sshAgentDialOptions{Endpoint: " /tmp/identity-agent.sock "})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("resolveSSHAgentEndpoint: %v", err)
|
||||||
|
}
|
||||||
|
if resolved.Endpoint != "/tmp/identity-agent.sock" {
|
||||||
|
t.Fatalf("endpoint=%q", resolved.Endpoint)
|
||||||
|
}
|
||||||
|
if resolved.Source != "identity-agent" {
|
||||||
|
t.Fatalf("source=%q", resolved.Source)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveSSHAgentEndpointUsesSSHAuthSock(t *testing.T) {
|
||||||
|
t.Setenv("SSH_AUTH_SOCK", "/tmp/env-agent.sock")
|
||||||
|
resolved, err := resolveSSHAgentEndpoint(sshAgentDialOptions{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("resolveSSHAgentEndpoint: %v", err)
|
||||||
|
}
|
||||||
|
if resolved.Endpoint != "/tmp/env-agent.sock" {
|
||||||
|
t.Fatalf("endpoint=%q", resolved.Endpoint)
|
||||||
|
}
|
||||||
|
if resolved.Source != "SSH_AUTH_SOCK" {
|
||||||
|
t.Fatalf("source=%q", resolved.Source)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildSSHAgentAuthMethodTimesOutWhenAgentDoesNotRespond(t *testing.T) {
|
||||||
|
server, client := net.Pipe()
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
oldDialResolvedSSHAgent := dialResolvedSSHAgentFunc
|
||||||
|
t.Cleanup(func() {
|
||||||
|
dialResolvedSSHAgentFunc = oldDialResolvedSSHAgent
|
||||||
|
})
|
||||||
|
dialResolvedSSHAgentFunc = func(resolved resolvedSSHAgentEndpoint, timeout time.Duration) (net.Conn, error) {
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, cleanup, err := buildSSHAgentAuthMethod(sshAgentTimeouts{
|
||||||
|
Operation: 20 * time.Millisecond,
|
||||||
|
Endpoint: "/tmp/hung-agent.sock",
|
||||||
|
})
|
||||||
|
if cleanup != nil {
|
||||||
|
cleanup()
|
||||||
|
}
|
||||||
|
if !errors.Is(err, ErrSSHAgentTimeout) {
|
||||||
|
t.Fatalf("err=%v want ErrSSHAgentTimeout", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildSSHAgentAuthMethodEmitsDebugEvents(t *testing.T) {
|
||||||
|
socketPath := tempUnixSocketPath(t)
|
||||||
|
listener, err := net.Listen("unix", socketPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen unix: %v", err)
|
||||||
|
}
|
||||||
|
defer listener.Close()
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = conn.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
var events []SSHAgentDebugEvent
|
||||||
|
_, _, _ = buildSSHAgentAuthMethod(sshAgentTimeouts{
|
||||||
|
Dial: time.Second,
|
||||||
|
Operation: time.Second,
|
||||||
|
Endpoint: socketPath,
|
||||||
|
Debug: func(event SSHAgentDebugEvent) {
|
||||||
|
events = append(events, event)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
<-done
|
||||||
|
|
||||||
|
if len(events) == 0 {
|
||||||
|
t.Fatal("expected debug events")
|
||||||
|
}
|
||||||
|
if events[0].Step != "auth" || events[0].Phase != "dial" {
|
||||||
|
t.Fatalf("unexpected first event: %+v", events[0])
|
||||||
|
}
|
||||||
|
if events[0].Endpoint != socketPath || events[0].Source != "identity-agent" {
|
||||||
|
t.Fatalf("unexpected endpoint event: %+v", events[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func tempUnixSocketPath(t *testing.T) string {
|
||||||
|
t.Helper()
|
||||||
|
path := t.TempDir() + "/agent.sock"
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = os.Remove(path)
|
||||||
|
})
|
||||||
|
return path
|
||||||
|
}
|
||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/sftp"
|
"github.com/pkg/sftp"
|
||||||
@@ -36,6 +37,19 @@ type SFTPTransferOptions struct {
|
|||||||
VerifySize *bool
|
VerifySize *bool
|
||||||
VerifyChecksum *bool
|
VerifyChecksum *bool
|
||||||
TempSuffix string
|
TempSuffix string
|
||||||
|
Client SFTPClientOptions
|
||||||
|
}
|
||||||
|
|
||||||
|
// SFTPClientOptions controls the underlying SFTP protocol client.
|
||||||
|
//
|
||||||
|
// These options only apply when StarSSH creates the SFTP client internally.
|
||||||
|
// They are intentionally separate from BufferSize, which only controls the
|
||||||
|
// local copy buffer used by transfer progress reporting.
|
||||||
|
type SFTPClientOptions struct {
|
||||||
|
MaxPacketSize int
|
||||||
|
MaxConcurrentRequestsPerFile int
|
||||||
|
ConcurrentReads *bool
|
||||||
|
ConcurrentWrites *bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type resolvedSFTPTransferOptions struct {
|
type resolvedSFTPTransferOptions struct {
|
||||||
@@ -48,6 +62,38 @@ type resolvedSFTPTransferOptions struct {
|
|||||||
VerifySize bool
|
VerifySize bool
|
||||||
VerifyChecksum bool
|
VerifyChecksum bool
|
||||||
TempSuffix string
|
TempSuffix string
|
||||||
|
Client resolvedSFTPClientOptions
|
||||||
|
}
|
||||||
|
|
||||||
|
type resolvedSFTPClientOptions struct {
|
||||||
|
MaxPacketSize int
|
||||||
|
MaxConcurrentRequestsPerFile int
|
||||||
|
ConcurrentReads *bool
|
||||||
|
ConcurrentWrites *bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type sftpConcurrentReaderFrom interface {
|
||||||
|
ReadFromWithConcurrency(io.Reader, int) (int64, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type sftpUploadProgressReader struct {
|
||||||
|
ctx context.Context
|
||||||
|
reader io.Reader
|
||||||
|
total int64
|
||||||
|
progress func(float64)
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
copied int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type sftpDownloadProgressWriter struct {
|
||||||
|
ctx context.Context
|
||||||
|
writer io.Writer
|
||||||
|
total int64
|
||||||
|
progress func(float64)
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
copied int64
|
||||||
}
|
}
|
||||||
|
|
||||||
type SFTPErrorCategory string
|
type SFTPErrorCategory string
|
||||||
@@ -106,6 +152,7 @@ var (
|
|||||||
sftpVerifyLocalSizeFunc = verifyLocalSize
|
sftpVerifyLocalSizeFunc = verifyLocalSize
|
||||||
sftpLocalFileSHA256Func = localFileSHA256
|
sftpLocalFileSHA256Func = localFileSHA256
|
||||||
sftpRemoteFileSHA256Func = remoteFileSHA256
|
sftpRemoteFileSHA256Func = remoteFileSHA256
|
||||||
|
sftpNewClientFunc = sftp.NewClient
|
||||||
)
|
)
|
||||||
|
|
||||||
func DefaultSFTPTransferOptions() SFTPTransferOptions {
|
func DefaultSFTPTransferOptions() SFTPTransferOptions {
|
||||||
@@ -121,6 +168,16 @@ func DefaultSFTPTransferOptions() SFTPTransferOptions {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ThroughputSFTPTransferOptions() SFTPTransferOptions {
|
||||||
|
opts := DefaultSFTPTransferOptions()
|
||||||
|
opts.Client = SFTPClientOptions{
|
||||||
|
ConcurrentReads: SFTPBool(true),
|
||||||
|
ConcurrentWrites: SFTPBool(true),
|
||||||
|
MaxConcurrentRequestsPerFile: 32,
|
||||||
|
}
|
||||||
|
return opts
|
||||||
|
}
|
||||||
|
|
||||||
func SFTPBool(value bool) *bool {
|
func SFTPBool(value bool) *bool {
|
||||||
return &value
|
return &value
|
||||||
}
|
}
|
||||||
@@ -316,51 +373,40 @@ func (fs *SFTPFileSystem) Rename(ctx context.Context, oldPath string, newPath st
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func normalizeSFTPTransferOptions(options *SFTPTransferOptions) resolvedSFTPTransferOptions {
|
func normalizeSFTPTransferOptions(options *SFTPTransferOptions) (resolvedSFTPTransferOptions, error) {
|
||||||
opts := DefaultSFTPTransferOptions()
|
opts := DefaultSFTPTransferOptions()
|
||||||
if options == nil {
|
if options != nil {
|
||||||
return resolvedSFTPTransferOptions{
|
if options.BufferSize > 0 {
|
||||||
BufferSize: opts.BufferSize,
|
opts.BufferSize = options.BufferSize
|
||||||
Progress: opts.Progress,
|
|
||||||
RetryCount: normalizeSFTPRetryCount(derefSFTPInt(opts.RetryCount, defaultSFTPRetryCount)),
|
|
||||||
RetryInitialBackoff: derefSFTPDuration(opts.RetryInitialBackoff, defaultSFTPRetryInitialBackoff),
|
|
||||||
AtomicUpload: derefSFTPBool(opts.AtomicUpload, true),
|
|
||||||
AtomicDownload: derefSFTPBool(opts.AtomicDownload, true),
|
|
||||||
VerifySize: derefSFTPBool(opts.VerifySize, true),
|
|
||||||
VerifyChecksum: derefSFTPBool(opts.VerifyChecksum, false),
|
|
||||||
TempSuffix: normalizeSFTPTempSuffix(opts.TempSuffix),
|
|
||||||
}
|
}
|
||||||
|
if options.Progress != nil {
|
||||||
|
opts.Progress = options.Progress
|
||||||
|
}
|
||||||
|
if options.RetryCount != nil {
|
||||||
|
opts.RetryCount = options.RetryCount
|
||||||
|
}
|
||||||
|
if options.RetryInitialBackoff != nil {
|
||||||
|
opts.RetryInitialBackoff = options.RetryInitialBackoff
|
||||||
|
}
|
||||||
|
if options.AtomicUpload != nil {
|
||||||
|
opts.AtomicUpload = options.AtomicUpload
|
||||||
|
}
|
||||||
|
if options.AtomicDownload != nil {
|
||||||
|
opts.AtomicDownload = options.AtomicDownload
|
||||||
|
}
|
||||||
|
if options.VerifySize != nil {
|
||||||
|
opts.VerifySize = options.VerifySize
|
||||||
|
}
|
||||||
|
if options.VerifyChecksum != nil {
|
||||||
|
opts.VerifyChecksum = options.VerifyChecksum
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(options.TempSuffix) != "" {
|
||||||
|
opts.TempSuffix = options.TempSuffix
|
||||||
|
}
|
||||||
|
opts.Client = options.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
if options.BufferSize > 0 {
|
resolved := resolvedSFTPTransferOptions{
|
||||||
opts.BufferSize = options.BufferSize
|
|
||||||
}
|
|
||||||
if options.Progress != nil {
|
|
||||||
opts.Progress = options.Progress
|
|
||||||
}
|
|
||||||
if options.RetryCount != nil {
|
|
||||||
opts.RetryCount = options.RetryCount
|
|
||||||
}
|
|
||||||
if options.RetryInitialBackoff != nil {
|
|
||||||
opts.RetryInitialBackoff = options.RetryInitialBackoff
|
|
||||||
}
|
|
||||||
if options.AtomicUpload != nil {
|
|
||||||
opts.AtomicUpload = options.AtomicUpload
|
|
||||||
}
|
|
||||||
if options.AtomicDownload != nil {
|
|
||||||
opts.AtomicDownload = options.AtomicDownload
|
|
||||||
}
|
|
||||||
if options.VerifySize != nil {
|
|
||||||
opts.VerifySize = options.VerifySize
|
|
||||||
}
|
|
||||||
if options.VerifyChecksum != nil {
|
|
||||||
opts.VerifyChecksum = options.VerifyChecksum
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(options.TempSuffix) != "" {
|
|
||||||
opts.TempSuffix = options.TempSuffix
|
|
||||||
}
|
|
||||||
|
|
||||||
return resolvedSFTPTransferOptions{
|
|
||||||
BufferSize: opts.BufferSize,
|
BufferSize: opts.BufferSize,
|
||||||
Progress: opts.Progress,
|
Progress: opts.Progress,
|
||||||
RetryCount: normalizeSFTPRetryCount(derefSFTPInt(opts.RetryCount, defaultSFTPRetryCount)),
|
RetryCount: normalizeSFTPRetryCount(derefSFTPInt(opts.RetryCount, defaultSFTPRetryCount)),
|
||||||
@@ -371,6 +417,11 @@ func normalizeSFTPTransferOptions(options *SFTPTransferOptions) resolvedSFTPTran
|
|||||||
VerifyChecksum: derefSFTPBool(opts.VerifyChecksum, false),
|
VerifyChecksum: derefSFTPBool(opts.VerifyChecksum, false),
|
||||||
TempSuffix: normalizeSFTPTempSuffix(opts.TempSuffix),
|
TempSuffix: normalizeSFTPTempSuffix(opts.TempSuffix),
|
||||||
}
|
}
|
||||||
|
resolved.Client = normalizeSFTPClientOptions(opts.Client)
|
||||||
|
if err := validateResolvedSFTPTransferOptions(resolved); err != nil {
|
||||||
|
return resolvedSFTPTransferOptions{}, err
|
||||||
|
}
|
||||||
|
return resolved, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func derefSFTPBool(value *bool, fallback bool) bool {
|
func derefSFTPBool(value *bool, fallback bool) bool {
|
||||||
@@ -409,13 +460,56 @@ func normalizeSFTPRetryCount(value int) int {
|
|||||||
return value
|
return value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalizeSFTPClientOptions(options SFTPClientOptions) resolvedSFTPClientOptions {
|
||||||
|
return resolvedSFTPClientOptions{
|
||||||
|
MaxPacketSize: options.MaxPacketSize,
|
||||||
|
MaxConcurrentRequestsPerFile: options.MaxConcurrentRequestsPerFile,
|
||||||
|
ConcurrentReads: options.ConcurrentReads,
|
||||||
|
ConcurrentWrites: options.ConcurrentWrites,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateResolvedSFTPTransferOptions(opts resolvedSFTPTransferOptions) error {
|
||||||
|
return validateResolvedSFTPClientOptions(opts.Client)
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateResolvedSFTPClientOptions(opts resolvedSFTPClientOptions) error {
|
||||||
|
if opts.MaxPacketSize < 0 {
|
||||||
|
return errors.New("sftp max packet size must not be negative")
|
||||||
|
}
|
||||||
|
if opts.MaxConcurrentRequestsPerFile < 0 {
|
||||||
|
return errors.New("sftp max concurrent requests per file must not be negative")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func rejectExternalSFTPClientOptions(opts resolvedSFTPClientOptions) error {
|
||||||
|
if opts.MaxPacketSize != 0 ||
|
||||||
|
opts.MaxConcurrentRequestsPerFile != 0 ||
|
||||||
|
opts.ConcurrentReads != nil ||
|
||||||
|
opts.ConcurrentWrites != nil {
|
||||||
|
return errors.New("sftp client options require StarSSH-managed SFTP client")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateSFTPUploadOptions(opts resolvedSFTPTransferOptions) error {
|
||||||
|
if derefSFTPBool(opts.Client.ConcurrentWrites, false) && !opts.AtomicUpload {
|
||||||
|
return errors.New("sftp concurrent writes require atomic upload")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (star *StarSSH) runSFTPClientOperation(ctx context.Context, operation string, remotePath string, fn func(*sftp.Client) error) error {
|
func (star *StarSSH) runSFTPClientOperation(ctx context.Context, operation string, remotePath string, fn func(*sftp.Client) error) error {
|
||||||
if err := ensureContext(ctx); err != nil {
|
if err := ensureContext(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
opts := normalizeSFTPTransferOptions(nil)
|
opts, err := normalizeSFTPTransferOptions(nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return executeSFTPRetry(ctx, operation, "", remotePath, opts, func(attempt int) error {
|
return executeSFTPRetry(ctx, operation, "", remotePath, opts, func(attempt int) error {
|
||||||
return star.withIsolatedSFTPClient(ctx, fn)
|
return star.withIsolatedSFTPClient(ctx, opts.Client, fn)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -423,23 +517,27 @@ func (star *StarSSH) runSFTPClientOperationNoRetry(ctx context.Context, fn func(
|
|||||||
if err := ensureContext(ctx); err != nil {
|
if err := ensureContext(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return star.withIsolatedSFTPClient(ctx, fn)
|
return star.withIsolatedSFTPClient(ctx, resolvedSFTPClientOptions{}, fn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (star *StarSSH) CreateSftpClient() (*sftp.Client, error) {
|
func (star *StarSSH) CreateSftpClient() (*sftp.Client, error) {
|
||||||
|
return star.createSFTPClientWithOptions(resolvedSFTPClientOptions{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (star *StarSSH) createSFTPClientWithOptions(options resolvedSFTPClientOptions) (*sftp.Client, error) {
|
||||||
client, err := star.requireSSHClient()
|
client, err := star.requireSSHClient()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return sftp.NewClient(client)
|
return sftpNewClientFunc(client, buildSFTPClientOptions(options)...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (star *StarSSH) withIsolatedSFTPClient(ctx context.Context, fn func(*sftp.Client) error) error {
|
func (star *StarSSH) withIsolatedSFTPClient(ctx context.Context, options resolvedSFTPClientOptions, fn func(*sftp.Client) error) error {
|
||||||
if err := ensureContext(ctx); err != nil {
|
if err := ensureContext(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := star.CreateSftpClient()
|
client, err := star.createSFTPClientWithOptions(options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -448,6 +546,23 @@ func (star *StarSSH) withIsolatedSFTPClient(ctx context.Context, fn func(*sftp.C
|
|||||||
return fn(client)
|
return fn(client)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildSFTPClientOptions(options resolvedSFTPClientOptions) []sftp.ClientOption {
|
||||||
|
clientOptions := make([]sftp.ClientOption, 0, 4)
|
||||||
|
if options.MaxPacketSize > 0 {
|
||||||
|
clientOptions = append(clientOptions, sftp.MaxPacketChecked(options.MaxPacketSize))
|
||||||
|
}
|
||||||
|
if options.MaxConcurrentRequestsPerFile > 0 {
|
||||||
|
clientOptions = append(clientOptions, sftp.MaxConcurrentRequestsPerFile(options.MaxConcurrentRequestsPerFile))
|
||||||
|
}
|
||||||
|
if options.ConcurrentReads != nil {
|
||||||
|
clientOptions = append(clientOptions, sftp.UseConcurrentReads(*options.ConcurrentReads))
|
||||||
|
}
|
||||||
|
if options.ConcurrentWrites != nil {
|
||||||
|
clientOptions = append(clientOptions, sftp.UseConcurrentWrites(*options.ConcurrentWrites))
|
||||||
|
}
|
||||||
|
return clientOptions
|
||||||
|
}
|
||||||
|
|
||||||
func (star *StarSSH) getReusableSFTPClient() (*sftp.Client, error) {
|
func (star *StarSSH) getReusableSFTPClient() (*sftp.Client, error) {
|
||||||
if star == nil {
|
if star == nil {
|
||||||
return nil, errors.New("ssh client is nil")
|
return nil, errors.New("ssh client is nil")
|
||||||
@@ -526,7 +641,7 @@ func (star *StarSSH) runSFTPWithRetry(
|
|||||||
fn func(context.Context, *sftp.Client, resolvedSFTPTransferOptions) error,
|
fn func(context.Context, *sftp.Client, resolvedSFTPTransferOptions) error,
|
||||||
) error {
|
) error {
|
||||||
return executeSFTPRetry(ctx, operation, localPath, remotePath, opts, func(attempt int) error {
|
return executeSFTPRetry(ctx, operation, localPath, remotePath, opts, func(attempt int) error {
|
||||||
return star.withIsolatedSFTPClient(ctx, func(client *sftp.Client) error {
|
return star.withIsolatedSFTPClient(ctx, opts.Client, func(client *sftp.Client) error {
|
||||||
return fn(ctx, client, opts)
|
return fn(ctx, client, opts)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@@ -537,7 +652,13 @@ func (star *StarSSH) SftpTransferOut(localFilePath, remotePath string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (star *StarSSH) SftpTransferOutContext(ctx context.Context, localFilePath, remotePath string, options *SFTPTransferOptions) error {
|
func (star *StarSSH) SftpTransferOutContext(ctx context.Context, localFilePath, remotePath string, options *SFTPTransferOptions) error {
|
||||||
opts := normalizeSFTPTransferOptions(options)
|
opts, err := normalizeSFTPTransferOptions(options)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := validateSFTPUploadOptions(opts); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return star.runSFTPWithRetry(ctx, "sftp_put_file", localFilePath, remotePath, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error {
|
return star.runSFTPWithRetry(ctx, "sftp_put_file", localFilePath, remotePath, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error {
|
||||||
return transferOutContext(ctx, client, localFilePath, remotePath, opts)
|
return transferOutContext(ctx, client, localFilePath, remotePath, opts)
|
||||||
})
|
})
|
||||||
@@ -548,7 +669,16 @@ func SftpTransferOut(localFilePath, remotePath string, sftpClient *sftp.Client)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func SftpTransferOutWithContext(ctx context.Context, localFilePath, remotePath string, sftpClient *sftp.Client, options *SFTPTransferOptions) error {
|
func SftpTransferOutWithContext(ctx context.Context, localFilePath, remotePath string, sftpClient *sftp.Client, options *SFTPTransferOptions) error {
|
||||||
opts := normalizeSFTPTransferOptions(options)
|
opts, err := normalizeSFTPTransferOptions(options)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := validateSFTPUploadOptions(opts); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := rejectExternalSFTPClientOptions(opts.Client); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return executeSFTPRetry(ctx, "sftp_put_file", localFilePath, remotePath, opts, func(attempt int) error {
|
return executeSFTPRetry(ctx, "sftp_put_file", localFilePath, remotePath, opts, func(attempt int) error {
|
||||||
return transferOutContext(ctx, sftpClient, localFilePath, remotePath, opts)
|
return transferOutContext(ctx, sftpClient, localFilePath, remotePath, opts)
|
||||||
})
|
})
|
||||||
@@ -559,7 +689,13 @@ func (star *StarSSH) SftpTransferOutByte(localData []byte, remotePath string) er
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (star *StarSSH) SftpTransferOutByteContext(ctx context.Context, localData []byte, remotePath string, options *SFTPTransferOptions) error {
|
func (star *StarSSH) SftpTransferOutByteContext(ctx context.Context, localData []byte, remotePath string, options *SFTPTransferOptions) error {
|
||||||
opts := normalizeSFTPTransferOptions(options)
|
opts, err := normalizeSFTPTransferOptions(options)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := validateSFTPUploadOptions(opts); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return star.runSFTPWithRetry(ctx, "sftp_put_bytes", "", remotePath, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error {
|
return star.runSFTPWithRetry(ctx, "sftp_put_bytes", "", remotePath, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error {
|
||||||
return transferOutByteContext(ctx, client, localData, remotePath, opts)
|
return transferOutByteContext(ctx, client, localData, remotePath, opts)
|
||||||
})
|
})
|
||||||
@@ -570,7 +706,16 @@ func SftpTransferOutByte(localData []byte, remotePath string, sftpClient *sftp.C
|
|||||||
}
|
}
|
||||||
|
|
||||||
func SftpTransferOutByteWithContext(ctx context.Context, localData []byte, remotePath string, sftpClient *sftp.Client, options *SFTPTransferOptions) error {
|
func SftpTransferOutByteWithContext(ctx context.Context, localData []byte, remotePath string, sftpClient *sftp.Client, options *SFTPTransferOptions) error {
|
||||||
opts := normalizeSFTPTransferOptions(options)
|
opts, err := normalizeSFTPTransferOptions(options)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := validateSFTPUploadOptions(opts); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := rejectExternalSFTPClientOptions(opts.Client); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return executeSFTPRetry(ctx, "sftp_put_bytes", "", remotePath, opts, func(attempt int) error {
|
return executeSFTPRetry(ctx, "sftp_put_bytes", "", remotePath, opts, func(attempt int) error {
|
||||||
return transferOutByteContext(ctx, sftpClient, localData, remotePath, opts)
|
return transferOutByteContext(ctx, sftpClient, localData, remotePath, opts)
|
||||||
})
|
})
|
||||||
@@ -595,10 +740,13 @@ func (star *StarSSH) SftpTransferInByte(remotePath string) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (star *StarSSH) SftpTransferInByteContext(ctx context.Context, remotePath string, options *SFTPTransferOptions) ([]byte, error) {
|
func (star *StarSSH) SftpTransferInByteContext(ctx context.Context, remotePath string, options *SFTPTransferOptions) ([]byte, error) {
|
||||||
opts := normalizeSFTPTransferOptions(options)
|
opts, err := normalizeSFTPTransferOptions(options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
var data []byte
|
var data []byte
|
||||||
err := star.runSFTPWithRetry(ctx, "sftp_get_bytes", "", remotePath, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error {
|
err = star.runSFTPWithRetry(ctx, "sftp_get_bytes", "", remotePath, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error {
|
||||||
out, runErr := transferInByteContext(ctx, client, remotePath, opts)
|
out, runErr := transferInByteContext(ctx, client, remotePath, opts)
|
||||||
if runErr != nil {
|
if runErr != nil {
|
||||||
return runErr
|
return runErr
|
||||||
@@ -617,10 +765,16 @@ func SftpTransferInByte(remotePath string, sftpClient *sftp.Client) ([]byte, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
func SftpTransferInByteWithContext(ctx context.Context, remotePath string, sftpClient *sftp.Client, options *SFTPTransferOptions) ([]byte, error) {
|
func SftpTransferInByteWithContext(ctx context.Context, remotePath string, sftpClient *sftp.Client, options *SFTPTransferOptions) ([]byte, error) {
|
||||||
opts := normalizeSFTPTransferOptions(options)
|
opts, err := normalizeSFTPTransferOptions(options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := rejectExternalSFTPClientOptions(opts.Client); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
var data []byte
|
var data []byte
|
||||||
err := executeSFTPRetry(ctx, "sftp_get_bytes", "", remotePath, opts, func(attempt int) error {
|
err = executeSFTPRetry(ctx, "sftp_get_bytes", "", remotePath, opts, func(attempt int) error {
|
||||||
out, runErr := transferInByteContext(ctx, sftpClient, remotePath, opts)
|
out, runErr := transferInByteContext(ctx, sftpClient, remotePath, opts)
|
||||||
if runErr != nil {
|
if runErr != nil {
|
||||||
return runErr
|
return runErr
|
||||||
@@ -639,7 +793,10 @@ func (star *StarSSH) SftpTransferIn(src, dst string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (star *StarSSH) SftpTransferInContext(ctx context.Context, src, dst string, options *SFTPTransferOptions) error {
|
func (star *StarSSH) SftpTransferInContext(ctx context.Context, src, dst string, options *SFTPTransferOptions) error {
|
||||||
opts := normalizeSFTPTransferOptions(options)
|
opts, err := normalizeSFTPTransferOptions(options)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return star.runSFTPWithRetry(ctx, "sftp_get_file", dst, src, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error {
|
return star.runSFTPWithRetry(ctx, "sftp_get_file", dst, src, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error {
|
||||||
return transferInContext(ctx, client, src, dst, opts)
|
return transferInContext(ctx, client, src, dst, opts)
|
||||||
})
|
})
|
||||||
@@ -650,7 +807,13 @@ func SftpTransferIn(src, dst string, sftpClient *sftp.Client) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func SftpTransferInWithContext(ctx context.Context, src, dst string, sftpClient *sftp.Client, options *SFTPTransferOptions) error {
|
func SftpTransferInWithContext(ctx context.Context, src, dst string, sftpClient *sftp.Client, options *SFTPTransferOptions) error {
|
||||||
opts := normalizeSFTPTransferOptions(options)
|
opts, err := normalizeSFTPTransferOptions(options)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := rejectExternalSFTPClientOptions(opts.Client); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return executeSFTPRetry(ctx, "sftp_get_file", dst, src, opts, func(attempt int) error {
|
return executeSFTPRetry(ctx, "sftp_get_file", dst, src, opts, func(attempt int) error {
|
||||||
return transferInContext(ctx, sftpClient, src, dst, opts)
|
return transferInContext(ctx, sftpClient, src, dst, opts)
|
||||||
})
|
})
|
||||||
@@ -712,7 +875,7 @@ func transferOutContext(ctx context.Context, sftpClient *sftp.Client, localFileP
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := sftpCopyWithProgressFunc(ctx, dstFile, srcFile, opts.BufferSize, stat.Size(), opts.Progress); err != nil {
|
if _, err := copyUploadWithProgressContext(ctx, dstFile, srcFile, opts.BufferSize, stat.Size(), opts.Progress, opts); err != nil {
|
||||||
_ = dstFile.Close()
|
_ = dstFile.Close()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -792,7 +955,7 @@ func transferOutByteContext(ctx context.Context, sftpClient *sftp.Client, localD
|
|||||||
}
|
}
|
||||||
|
|
||||||
reader := bytes.NewReader(localData)
|
reader := bytes.NewReader(localData)
|
||||||
if _, err := sftpCopyWithProgressFunc(ctx, dstFile, reader, opts.BufferSize, int64(len(localData)), opts.Progress); err != nil {
|
if _, err := copyUploadWithProgressContext(ctx, dstFile, reader, opts.BufferSize, int64(len(localData)), opts.Progress, opts); err != nil {
|
||||||
_ = dstFile.Close()
|
_ = dstFile.Close()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -877,7 +1040,7 @@ func transferInContext(ctx context.Context, sftpClient *sftp.Client, src, dst st
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := sftpCopyWithProgressFunc(ctx, dstFile, srcFile, opts.BufferSize, stat.Size(), opts.Progress); err != nil {
|
if _, err := copyDownloadWithProgressContext(ctx, dstFile, srcFile, opts.BufferSize, stat.Size(), opts.Progress, opts); err != nil {
|
||||||
_ = dstFile.Close()
|
_ = dstFile.Close()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -948,7 +1111,7 @@ func transferInByteContext(ctx context.Context, sftpClient *sftp.Client, remoteP
|
|||||||
}
|
}
|
||||||
|
|
||||||
var out bytes.Buffer
|
var out bytes.Buffer
|
||||||
if _, err := sftpCopyWithProgressFunc(ctx, &out, srcFile, opts.BufferSize, stat.Size(), opts.Progress); err != nil {
|
if _, err := copyDownloadWithProgressContext(ctx, &out, srcFile, opts.BufferSize, stat.Size(), opts.Progress, opts); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1149,6 +1312,151 @@ func copyWithProgressContext(ctx context.Context, dst io.Writer, src io.Reader,
|
|||||||
return copied, nil
|
return copied, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func copyUploadWithProgressContext(
|
||||||
|
ctx context.Context,
|
||||||
|
dst io.Writer,
|
||||||
|
src io.Reader,
|
||||||
|
bufSize int,
|
||||||
|
total int64,
|
||||||
|
progress func(float64),
|
||||||
|
opts resolvedSFTPTransferOptions,
|
||||||
|
) (int64, error) {
|
||||||
|
if !derefSFTPBool(opts.Client.ConcurrentWrites, false) {
|
||||||
|
return sftpCopyWithProgressFunc(ctx, dst, src, bufSize, total, progress)
|
||||||
|
}
|
||||||
|
|
||||||
|
readerFrom, ok := dst.(sftpConcurrentReaderFrom)
|
||||||
|
if !ok {
|
||||||
|
return sftpCopyWithProgressFunc(ctx, dst, src, bufSize, total, progress)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ensureContext(ctx); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if progress != nil && total > 0 {
|
||||||
|
progress(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
wrappedSrc := &sftpUploadProgressReader{
|
||||||
|
ctx: ctx,
|
||||||
|
reader: src,
|
||||||
|
total: total,
|
||||||
|
progress: progress,
|
||||||
|
}
|
||||||
|
written, err := readerFrom.ReadFromWithConcurrency(wrappedSrc, opts.Client.MaxConcurrentRequestsPerFile)
|
||||||
|
if err != nil {
|
||||||
|
return written, err
|
||||||
|
}
|
||||||
|
if err := ensureContext(ctx); err != nil {
|
||||||
|
return written, err
|
||||||
|
}
|
||||||
|
reportProgress(progress, written, total)
|
||||||
|
return written, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyDownloadWithProgressContext(
|
||||||
|
ctx context.Context,
|
||||||
|
dst io.Writer,
|
||||||
|
src io.Reader,
|
||||||
|
bufSize int,
|
||||||
|
total int64,
|
||||||
|
progress func(float64),
|
||||||
|
opts resolvedSFTPTransferOptions,
|
||||||
|
) (int64, error) {
|
||||||
|
if !derefSFTPBool(opts.Client.ConcurrentReads, false) {
|
||||||
|
return sftpCopyWithProgressFunc(ctx, dst, src, bufSize, total, progress)
|
||||||
|
}
|
||||||
|
|
||||||
|
writerTo, ok := src.(io.WriterTo)
|
||||||
|
if !ok {
|
||||||
|
return sftpCopyWithProgressFunc(ctx, dst, src, bufSize, total, progress)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ensureContext(ctx); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if progress != nil && total > 0 {
|
||||||
|
progress(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
wrappedDst := &sftpDownloadProgressWriter{
|
||||||
|
ctx: ctx,
|
||||||
|
writer: dst,
|
||||||
|
total: total,
|
||||||
|
progress: progress,
|
||||||
|
}
|
||||||
|
written, err := writerTo.WriteTo(wrappedDst)
|
||||||
|
if err != nil {
|
||||||
|
return written, err
|
||||||
|
}
|
||||||
|
if err := ensureContext(ctx); err != nil {
|
||||||
|
return written, err
|
||||||
|
}
|
||||||
|
reportProgress(progress, written, total)
|
||||||
|
return written, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *sftpUploadProgressReader) Read(p []byte) (int, error) {
|
||||||
|
if err := ensureContext(r.ctx); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
if err := ensureContext(r.ctx); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := r.reader.Read(p)
|
||||||
|
if n > 0 {
|
||||||
|
r.copied += int64(n)
|
||||||
|
reportQueuedTransferProgress(r.progress, r.copied, r.total)
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *sftpDownloadProgressWriter) Write(p []byte) (int, error) {
|
||||||
|
if err := ensureContext(w.ctx); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
|
if err := ensureContext(w.ctx); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := w.writer.Write(p)
|
||||||
|
if n > 0 {
|
||||||
|
w.copied += int64(n)
|
||||||
|
reportQueuedTransferProgress(w.progress, w.copied, w.total)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
if n != len(p) {
|
||||||
|
return n, io.ErrShortWrite
|
||||||
|
}
|
||||||
|
if err := ensureContext(w.ctx); err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func reportQueuedTransferProgress(progress func(float64), copied int64, total int64) {
|
||||||
|
if progress == nil || total <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
percent := float64(copied) / float64(total) * 100
|
||||||
|
if percent >= 100 {
|
||||||
|
percent = 99
|
||||||
|
}
|
||||||
|
progress(percent)
|
||||||
|
}
|
||||||
|
|
||||||
func reportProgress(progress func(float64), copied int64, total int64) {
|
func reportProgress(progress func(float64), copied int64, total int64) {
|
||||||
if progress == nil {
|
if progress == nil {
|
||||||
return
|
return
|
||||||
|
|||||||
+666
-14
@@ -1,6 +1,7 @@
|
|||||||
package starssh
|
package starssh
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
@@ -11,16 +12,521 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pkg/sftp"
|
"github.com/pkg/sftp"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNormalizeSFTPTransferOptionsDefaultsAtomicDownload(t *testing.T) {
|
func TestNormalizeSFTPTransferOptionsDefaultsAtomicDownload(t *testing.T) {
|
||||||
opts := normalizeSFTPTransferOptions(nil)
|
opts := mustNormalizeSFTPTransferOptions(t, nil)
|
||||||
if !opts.AtomicUpload {
|
if !opts.AtomicUpload {
|
||||||
t.Fatal("expected atomic upload to default to enabled")
|
t.Fatal("expected atomic upload to default to enabled")
|
||||||
}
|
}
|
||||||
if !opts.AtomicDownload {
|
if !opts.AtomicDownload {
|
||||||
t.Fatal("expected atomic download to default to enabled")
|
t.Fatal("expected atomic download to default to enabled")
|
||||||
}
|
}
|
||||||
|
if opts.Client.ConcurrentWrites != nil {
|
||||||
|
t.Fatal("expected concurrent writes to default to unset")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestThroughputSFTPTransferOptionsEnablesExplicitConcurrentWrites(t *testing.T) {
|
||||||
|
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||||
|
if !opts.AtomicUpload {
|
||||||
|
t.Fatal("expected throughput preset to keep atomic upload enabled")
|
||||||
|
}
|
||||||
|
if opts.Client.ConcurrentReads == nil || !*opts.Client.ConcurrentReads {
|
||||||
|
t.Fatal("expected throughput preset to enable concurrent reads")
|
||||||
|
}
|
||||||
|
if opts.Client.ConcurrentWrites == nil || !*opts.Client.ConcurrentWrites {
|
||||||
|
t.Fatal("expected throughput preset to enable concurrent writes")
|
||||||
|
}
|
||||||
|
if opts.Client.MaxConcurrentRequestsPerFile != 32 {
|
||||||
|
t.Fatalf("unexpected max concurrent requests: got %d want 32", opts.Client.MaxConcurrentRequestsPerFile)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateSFTPUploadOptionsRejectsConcurrentWritesWithoutAtomicUpload(t *testing.T) {
|
||||||
|
options := ThroughputSFTPTransferOptions()
|
||||||
|
options.AtomicUpload = SFTPBool(false)
|
||||||
|
opts := mustNormalizeSFTPTransferOptions(t, &options)
|
||||||
|
|
||||||
|
err := validateSFTPUploadOptions(opts)
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "atomic upload") {
|
||||||
|
t.Fatalf("expected atomic upload rejection, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeSFTPTransferOptionsRejectsNegativeClientValues(t *testing.T) {
|
||||||
|
_, err := normalizeSFTPTransferOptions(&SFTPTransferOptions{
|
||||||
|
Client: SFTPClientOptions{MaxPacketSize: -1},
|
||||||
|
})
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "max packet") {
|
||||||
|
t.Fatalf("expected max packet rejection, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = normalizeSFTPTransferOptions(&SFTPTransferOptions{
|
||||||
|
Client: SFTPClientOptions{MaxConcurrentRequestsPerFile: -1},
|
||||||
|
})
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "max concurrent") {
|
||||||
|
t.Fatalf("expected max concurrent rejection, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildSFTPClientOptionsRejectsUnsupportedCheckedPacketSize(t *testing.T) {
|
||||||
|
options := buildSFTPClientOptions(resolvedSFTPClientOptions{MaxPacketSize: 32769})
|
||||||
|
client := &sftp.Client{}
|
||||||
|
|
||||||
|
err := options[0](client)
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "32KB") {
|
||||||
|
t.Fatalf("expected checked packet size rejection, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSftpTransferOutWithContextRejectsClientOptionsForExternalClient(t *testing.T) {
|
||||||
|
client := newSFTPTestClient(t)
|
||||||
|
root := t.TempDir()
|
||||||
|
localPath := filepath.Join(root, "local.txt")
|
||||||
|
remotePath := filepath.Join(root, "remote.txt")
|
||||||
|
if err := os.WriteFile(localPath, []byte("payload"), 0o644); err != nil {
|
||||||
|
t.Fatalf("write local file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := SftpTransferOutWithContext(context.Background(), localPath, remotePath, client, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "StarSSH-managed") {
|
||||||
|
t.Fatalf("expected external client option rejection, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSftpTransferOutContextPassesClientOptionsToManagedClient(t *testing.T) {
|
||||||
|
root := t.TempDir()
|
||||||
|
localPath := filepath.Join(root, "local.txt")
|
||||||
|
remotePath := filepath.Join(root, "remote.txt")
|
||||||
|
if err := os.WriteFile(localPath, []byte("payload"), 0o644); err != nil {
|
||||||
|
t.Fatalf("write local file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := newSFTPTestClient(t)
|
||||||
|
options := ThroughputSFTPTransferOptions()
|
||||||
|
var captured []sftp.ClientOption
|
||||||
|
oldNewClient := sftpNewClientFunc
|
||||||
|
sftpNewClientFunc = func(_ *ssh.Client, opts ...sftp.ClientOption) (*sftp.Client, error) {
|
||||||
|
captured = append([]sftp.ClientOption(nil), opts...)
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
sftpNewClientFunc = oldNewClient
|
||||||
|
})
|
||||||
|
|
||||||
|
star := &StarSSH{Client: &ssh.Client{}}
|
||||||
|
if err := star.SftpTransferOutContext(context.Background(), localPath, remotePath, &options); err != nil {
|
||||||
|
t.Fatalf("transfer out: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(captured) == 0 {
|
||||||
|
t.Fatal("expected managed SFTP client options to be passed to factory")
|
||||||
|
}
|
||||||
|
if got, want := len(captured), len(buildSFTPClientOptions(mustNormalizeSFTPTransferOptions(t, &options).Client)); got != want {
|
||||||
|
t.Fatalf("unexpected client option count: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSftpTransferInWithContextRejectsClientOptionsForExternalClient(t *testing.T) {
|
||||||
|
client := newSFTPTestClient(t)
|
||||||
|
root := t.TempDir()
|
||||||
|
srcPath := filepath.Join(root, "remote.txt")
|
||||||
|
dstPath := filepath.Join(root, "local.txt")
|
||||||
|
if err := os.WriteFile(srcPath, []byte("payload"), 0o644); err != nil {
|
||||||
|
t.Fatalf("write remote file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := SftpTransferInWithContext(context.Background(), srcPath, dstPath, client, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "StarSSH-managed") {
|
||||||
|
t.Fatalf("expected external client option rejection, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSftpTransferInContextPassesClientOptionsToManagedClient(t *testing.T) {
|
||||||
|
root := t.TempDir()
|
||||||
|
srcPath := filepath.Join(root, "remote.txt")
|
||||||
|
dstPath := filepath.Join(root, "local.txt")
|
||||||
|
if err := os.WriteFile(srcPath, []byte("payload"), 0o644); err != nil {
|
||||||
|
t.Fatalf("write remote file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := newSFTPTestClient(t)
|
||||||
|
options := ThroughputSFTPTransferOptions()
|
||||||
|
var captured []sftp.ClientOption
|
||||||
|
oldNewClient := sftpNewClientFunc
|
||||||
|
sftpNewClientFunc = func(_ *ssh.Client, opts ...sftp.ClientOption) (*sftp.Client, error) {
|
||||||
|
captured = append([]sftp.ClientOption(nil), opts...)
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
sftpNewClientFunc = oldNewClient
|
||||||
|
})
|
||||||
|
|
||||||
|
star := &StarSSH{Client: &ssh.Client{}}
|
||||||
|
if err := star.SftpTransferInContext(context.Background(), srcPath, dstPath, &options); err != nil {
|
||||||
|
t.Fatalf("transfer in: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(captured) == 0 {
|
||||||
|
t.Fatal("expected managed SFTP client options to be passed to factory")
|
||||||
|
}
|
||||||
|
if got, want := len(captured), len(buildSFTPClientOptions(mustNormalizeSFTPTransferOptions(t, &options).Client)); got != want {
|
||||||
|
t.Fatalf("unexpected client option count: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
assertFileContent(t, dstPath, "payload")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildSFTPClientOptionsCanTransfer(t *testing.T) {
|
||||||
|
client := newSFTPTestClientWithOptions(t, buildSFTPClientOptions(resolvedSFTPClientOptions{
|
||||||
|
MaxPacketSize: 4096,
|
||||||
|
MaxConcurrentRequestsPerFile: 4,
|
||||||
|
ConcurrentWrites: SFTPBool(true),
|
||||||
|
}))
|
||||||
|
root := t.TempDir()
|
||||||
|
localPath := filepath.Join(root, "local.txt")
|
||||||
|
remotePath := filepath.Join(root, "remote.txt")
|
||||||
|
|
||||||
|
if err := os.WriteFile(localPath, []byte(strings.Repeat("payload-", 2048)), 0o644); err != nil {
|
||||||
|
t.Fatalf("write local file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||||
|
if err := transferOutContext(context.Background(), client, localPath, remotePath, opts); err != nil {
|
||||||
|
t.Fatalf("transfer out with client options: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(remotePath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read remote file: %v", err)
|
||||||
|
}
|
||||||
|
if got := string(data); got != strings.Repeat("payload-", 2048) {
|
||||||
|
t.Fatalf("unexpected remote payload length: got %d", len(got))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildSFTPClientOptionsCanDownload(t *testing.T) {
|
||||||
|
client := newSFTPTestClientWithOptions(t, buildSFTPClientOptions(resolvedSFTPClientOptions{
|
||||||
|
MaxPacketSize: 4096,
|
||||||
|
MaxConcurrentRequestsPerFile: 4,
|
||||||
|
ConcurrentReads: SFTPBool(true),
|
||||||
|
}))
|
||||||
|
root := t.TempDir()
|
||||||
|
srcPath := filepath.Join(root, "remote.txt")
|
||||||
|
dstPath := filepath.Join(root, "local.txt")
|
||||||
|
|
||||||
|
payload := strings.Repeat("payload-", 2048)
|
||||||
|
if err := os.WriteFile(srcPath, []byte(payload), 0o644); err != nil {
|
||||||
|
t.Fatalf("write remote file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||||
|
if err := transferInContext(context.Background(), client, srcPath, dstPath, opts); err != nil {
|
||||||
|
t.Fatalf("transfer in with client options: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assertFileContent(t, dstPath, payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopyUploadWithProgressUsesConcurrentReadFromWhenEnabled(t *testing.T) {
|
||||||
|
dst := &spyConcurrentReadFrom{}
|
||||||
|
src := strings.NewReader("payload")
|
||||||
|
var progress []float64
|
||||||
|
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||||
|
|
||||||
|
written, err := copyUploadWithProgressContext(context.Background(), dst, src, 3, int64(len("payload")), func(value float64) {
|
||||||
|
progress = append(progress, value)
|
||||||
|
}, opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("copy upload: %v", err)
|
||||||
|
}
|
||||||
|
if written != int64(len("payload")) {
|
||||||
|
t.Fatalf("unexpected written bytes: got %d", written)
|
||||||
|
}
|
||||||
|
if !dst.usedReadFrom {
|
||||||
|
t.Fatal("expected concurrent ReadFrom path to be used")
|
||||||
|
}
|
||||||
|
if dst.concurrency != opts.Client.MaxConcurrentRequestsPerFile {
|
||||||
|
t.Fatalf("unexpected concurrency: got %d want %d", dst.concurrency, opts.Client.MaxConcurrentRequestsPerFile)
|
||||||
|
}
|
||||||
|
if got := dst.buf.String(); got != "payload" {
|
||||||
|
t.Fatalf("unexpected copied payload: got %q", got)
|
||||||
|
}
|
||||||
|
if len(progress) == 0 || progress[len(progress)-1] != 100 {
|
||||||
|
t.Fatalf("expected final progress 100, got %v", progress)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopyUploadWithProgressReportsDuringConcurrentReadFrom(t *testing.T) {
|
||||||
|
dst := &spyConcurrentReadFrom{}
|
||||||
|
src := &chunkedReader{
|
||||||
|
reader: strings.NewReader("payload"),
|
||||||
|
chunkSize: 2,
|
||||||
|
}
|
||||||
|
var progress []float64
|
||||||
|
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||||
|
|
||||||
|
written, err := copyUploadWithProgressContext(context.Background(), dst, src, 3, int64(len("payload")), func(value float64) {
|
||||||
|
progress = append(progress, value)
|
||||||
|
}, opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("copy upload: %v", err)
|
||||||
|
}
|
||||||
|
if written != int64(len("payload")) {
|
||||||
|
t.Fatalf("unexpected written bytes: got %d", written)
|
||||||
|
}
|
||||||
|
|
||||||
|
var sawIntermediate bool
|
||||||
|
for _, value := range progress {
|
||||||
|
if value > 0 && value < 100 {
|
||||||
|
sawIntermediate = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !sawIntermediate {
|
||||||
|
t.Fatalf("expected intermediate progress during concurrent readfrom, got %v", progress)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopyUploadWithProgressCancelsDuringConcurrentReadFrom(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
dst := &spyConcurrentReadFrom{}
|
||||||
|
src := &cancelAfterReadReader{
|
||||||
|
reader: strings.NewReader("payload"),
|
||||||
|
chunkSize: 3,
|
||||||
|
cancel: cancel,
|
||||||
|
}
|
||||||
|
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||||
|
|
||||||
|
written, err := copyUploadWithProgressContext(ctx, dst, src, 3, int64(len("payload")), nil, opts)
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatalf("expected context cancellation, got written=%d err=%v", written, err)
|
||||||
|
}
|
||||||
|
if !dst.usedReadFrom {
|
||||||
|
t.Fatal("expected concurrent ReadFrom path to be used")
|
||||||
|
}
|
||||||
|
if written <= 0 || written >= int64(len("payload")) {
|
||||||
|
t.Fatalf("expected partial write before cancellation, got %d", written)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopyUploadWithProgressDoesNotReportDoneBeforeConcurrentWriteError(t *testing.T) {
|
||||||
|
copyErr := errors.New("write status failed")
|
||||||
|
dst := &spyConcurrentReadFrom{errAfterRead: copyErr}
|
||||||
|
var progress []float64
|
||||||
|
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||||
|
|
||||||
|
written, err := copyUploadWithProgressContext(context.Background(), dst, strings.NewReader("payload"), 3, int64(len("payload")), func(value float64) {
|
||||||
|
progress = append(progress, value)
|
||||||
|
}, opts)
|
||||||
|
if !errors.Is(err, copyErr) {
|
||||||
|
t.Fatalf("expected concurrent write error, got written=%d err=%v", written, err)
|
||||||
|
}
|
||||||
|
if written != int64(len("payload")) {
|
||||||
|
t.Fatalf("expected read byte count before write error, got %d", written)
|
||||||
|
}
|
||||||
|
for _, value := range progress {
|
||||||
|
if value >= 100 {
|
||||||
|
t.Fatalf("progress reported completion before write success: %v", progress)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(progress) == 0 {
|
||||||
|
t.Fatal("expected queued progress before write error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopyUploadWithProgressReturnsConcurrentReadFromError(t *testing.T) {
|
||||||
|
copyErr := errors.New("readfrom failed")
|
||||||
|
dst := &spyConcurrentReadFrom{err: copyErr}
|
||||||
|
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||||
|
|
||||||
|
_, err := copyUploadWithProgressContext(context.Background(), dst, strings.NewReader("payload"), 3, int64(len("payload")), nil, opts)
|
||||||
|
if !errors.Is(err, copyErr) {
|
||||||
|
t.Fatalf("expected concurrent readfrom error, got %v", err)
|
||||||
|
}
|
||||||
|
if !dst.usedReadFrom {
|
||||||
|
t.Fatal("expected concurrent ReadFrom path to be used")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopyUploadWithProgressDefaultsToSequentialCopy(t *testing.T) {
|
||||||
|
oldCopy := sftpCopyWithProgressFunc
|
||||||
|
called := false
|
||||||
|
sftpCopyWithProgressFunc = func(ctx context.Context, dst io.Writer, src io.Reader, bufSize int, total int64, progress func(float64)) (int64, error) {
|
||||||
|
called = true
|
||||||
|
return oldCopy(ctx, dst, src, bufSize, total, progress)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
sftpCopyWithProgressFunc = oldCopy
|
||||||
|
})
|
||||||
|
|
||||||
|
var dst bytes.Buffer
|
||||||
|
written, err := copyUploadWithProgressContext(context.Background(), &dst, strings.NewReader("payload"), 3, int64(len("payload")), nil, mustNormalizeSFTPTransferOptions(t, nil))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("copy upload: %v", err)
|
||||||
|
}
|
||||||
|
if written != int64(len("payload")) {
|
||||||
|
t.Fatalf("unexpected written bytes: got %d", written)
|
||||||
|
}
|
||||||
|
if !called {
|
||||||
|
t.Fatal("expected default path to use existing copy helper")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopyDownloadWithProgressUsesConcurrentWriteToWhenEnabled(t *testing.T) {
|
||||||
|
src := &spyConcurrentWriteTo{payload: []byte("payload")}
|
||||||
|
var dst bytes.Buffer
|
||||||
|
var progress []float64
|
||||||
|
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||||
|
|
||||||
|
written, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), func(value float64) {
|
||||||
|
progress = append(progress, value)
|
||||||
|
}, opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("copy download: %v", err)
|
||||||
|
}
|
||||||
|
if written != int64(len("payload")) {
|
||||||
|
t.Fatalf("unexpected written bytes: got %d", written)
|
||||||
|
}
|
||||||
|
if !src.usedWriteTo {
|
||||||
|
t.Fatal("expected concurrent WriteTo path to be used")
|
||||||
|
}
|
||||||
|
if got := dst.String(); got != "payload" {
|
||||||
|
t.Fatalf("unexpected copied payload: got %q", got)
|
||||||
|
}
|
||||||
|
if len(progress) == 0 || progress[len(progress)-1] != 100 {
|
||||||
|
t.Fatalf("expected final progress 100, got %v", progress)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopyDownloadWithProgressReportsDuringConcurrentWriteTo(t *testing.T) {
|
||||||
|
src := &spyConcurrentWriteTo{
|
||||||
|
payload: []byte("payload"),
|
||||||
|
chunkSize: 2,
|
||||||
|
}
|
||||||
|
var dst bytes.Buffer
|
||||||
|
var progress []float64
|
||||||
|
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||||
|
|
||||||
|
written, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), func(value float64) {
|
||||||
|
progress = append(progress, value)
|
||||||
|
}, opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("copy download: %v", err)
|
||||||
|
}
|
||||||
|
if written != int64(len("payload")) {
|
||||||
|
t.Fatalf("unexpected written bytes: got %d", written)
|
||||||
|
}
|
||||||
|
|
||||||
|
var sawIntermediate bool
|
||||||
|
for _, value := range progress {
|
||||||
|
if value > 0 && value < 100 {
|
||||||
|
sawIntermediate = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !sawIntermediate {
|
||||||
|
t.Fatalf("expected intermediate progress during concurrent writeto, got %v", progress)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopyDownloadWithProgressCancelsDuringConcurrentWriteTo(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
src := &spyConcurrentWriteTo{
|
||||||
|
payload: []byte("payload"),
|
||||||
|
chunkSize: 3,
|
||||||
|
cancelAfterWrite: cancel,
|
||||||
|
}
|
||||||
|
var dst bytes.Buffer
|
||||||
|
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||||
|
|
||||||
|
written, err := copyDownloadWithProgressContext(ctx, &dst, src, 3, int64(len("payload")), nil, opts)
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatalf("expected context cancellation, got written=%d err=%v", written, err)
|
||||||
|
}
|
||||||
|
if !src.usedWriteTo {
|
||||||
|
t.Fatal("expected concurrent WriteTo path to be used")
|
||||||
|
}
|
||||||
|
if written <= 0 || written >= int64(len("payload")) {
|
||||||
|
t.Fatalf("expected partial write before cancellation, got %d", written)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopyDownloadWithProgressDoesNotReportDoneBeforeConcurrentReadError(t *testing.T) {
|
||||||
|
copyErr := errors.New("read status failed")
|
||||||
|
src := &spyConcurrentWriteTo{
|
||||||
|
payload: []byte("payload"),
|
||||||
|
errAfterWrite: copyErr,
|
||||||
|
}
|
||||||
|
var dst bytes.Buffer
|
||||||
|
var progress []float64
|
||||||
|
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||||
|
|
||||||
|
written, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), func(value float64) {
|
||||||
|
progress = append(progress, value)
|
||||||
|
}, opts)
|
||||||
|
if !errors.Is(err, copyErr) {
|
||||||
|
t.Fatalf("expected concurrent read error, got written=%d err=%v", written, err)
|
||||||
|
}
|
||||||
|
if written != int64(len("payload")) {
|
||||||
|
t.Fatalf("expected local byte count before read error, got %d", written)
|
||||||
|
}
|
||||||
|
for _, value := range progress {
|
||||||
|
if value >= 100 {
|
||||||
|
t.Fatalf("progress reported completion before download success: %v", progress)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(progress) == 0 {
|
||||||
|
t.Fatal("expected queued progress before read error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopyDownloadWithProgressReturnsConcurrentWriteToError(t *testing.T) {
|
||||||
|
copyErr := errors.New("writeto failed")
|
||||||
|
src := &spyConcurrentWriteTo{err: copyErr}
|
||||||
|
var dst bytes.Buffer
|
||||||
|
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||||
|
|
||||||
|
_, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), nil, opts)
|
||||||
|
if !errors.Is(err, copyErr) {
|
||||||
|
t.Fatalf("expected concurrent writeto error, got %v", err)
|
||||||
|
}
|
||||||
|
if !src.usedWriteTo {
|
||||||
|
t.Fatal("expected concurrent WriteTo path to be used")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopyDownloadWithProgressDefaultsToSequentialCopy(t *testing.T) {
|
||||||
|
oldCopy := sftpCopyWithProgressFunc
|
||||||
|
called := false
|
||||||
|
sftpCopyWithProgressFunc = func(ctx context.Context, dst io.Writer, src io.Reader, bufSize int, total int64, progress func(float64)) (int64, error) {
|
||||||
|
called = true
|
||||||
|
return oldCopy(ctx, dst, src, bufSize, total, progress)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
sftpCopyWithProgressFunc = oldCopy
|
||||||
|
})
|
||||||
|
|
||||||
|
var dst bytes.Buffer
|
||||||
|
src := &chunkedReader{
|
||||||
|
reader: strings.NewReader("payload"),
|
||||||
|
chunkSize: 2,
|
||||||
|
}
|
||||||
|
written, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), nil, mustNormalizeSFTPTransferOptions(t, nil))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("copy download: %v", err)
|
||||||
|
}
|
||||||
|
if written != int64(len("payload")) {
|
||||||
|
t.Fatalf("unexpected written bytes: got %d", written)
|
||||||
|
}
|
||||||
|
if !called {
|
||||||
|
t.Fatal("expected default path to use existing copy helper")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTransferOutContextVerifyFailurePreservesRemoteTarget(t *testing.T) {
|
func TestTransferOutContextVerifyFailurePreservesRemoteTarget(t *testing.T) {
|
||||||
@@ -47,7 +553,7 @@ func TestTransferOutContextVerifyFailurePreservesRemoteTarget(t *testing.T) {
|
|||||||
sftpVerifyRemoteSizeFunc = oldVerifyRemoteSize
|
sftpVerifyRemoteSizeFunc = oldVerifyRemoteSize
|
||||||
})
|
})
|
||||||
|
|
||||||
err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil))
|
err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil))
|
||||||
if !errors.Is(err, verifyErr) {
|
if !errors.Is(err, verifyErr) {
|
||||||
t.Fatalf("expected verify failure, got %v", err)
|
t.Fatalf("expected verify failure, got %v", err)
|
||||||
}
|
}
|
||||||
@@ -82,7 +588,7 @@ func TestTransferOutContextRejectsRemoteSymlinkTarget(t *testing.T) {
|
|||||||
t.Skipf("symlink unsupported: %v", err)
|
t.Skipf("symlink unsupported: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil))
|
err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil))
|
||||||
if err == nil || !strings.Contains(err.Error(), "symlink") {
|
if err == nil || !strings.Contains(err.Error(), "symlink") {
|
||||||
t.Fatalf("expected symlink rejection, got %v", err)
|
t.Fatalf("expected symlink rejection, got %v", err)
|
||||||
}
|
}
|
||||||
@@ -118,7 +624,7 @@ func TestTransferOutContextRejectsRemoteDirectoryTarget(t *testing.T) {
|
|||||||
t.Fatalf("mkdir remote target: %v", err)
|
t.Fatalf("mkdir remote target: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil))
|
err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil))
|
||||||
if err == nil || !strings.Contains(err.Error(), "directory") {
|
if err == nil || !strings.Contains(err.Error(), "directory") {
|
||||||
t.Fatalf("expected directory rejection, got %v", err)
|
t.Fatalf("expected directory rejection, got %v", err)
|
||||||
}
|
}
|
||||||
@@ -149,7 +655,7 @@ func TestTransferOutContextPreservesRemoteModeOnOverwrite(t *testing.T) {
|
|||||||
t.Fatalf("chmod remote file: %v", err)
|
t.Fatalf("chmod remote file: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil)); err != nil {
|
if err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil {
|
||||||
t.Fatalf("transfer out: %v", err)
|
t.Fatalf("transfer out: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -170,7 +676,7 @@ func TestTransferOutContextAppliesLocalModeForNewRemoteFile(t *testing.T) {
|
|||||||
t.Fatalf("chmod local file: %v", err)
|
t.Fatalf("chmod local file: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil)); err != nil {
|
if err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil {
|
||||||
t.Fatalf("transfer out: %v", err)
|
t.Fatalf("transfer out: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -190,7 +696,7 @@ func TestTransferOutByteContextPreservesRemoteModeOnOverwrite(t *testing.T) {
|
|||||||
t.Fatalf("chmod remote file: %v", err)
|
t.Fatalf("chmod remote file: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := transferOutByteContext(context.Background(), client, []byte("byte payload"), remotePath, normalizeSFTPTransferOptions(nil)); err != nil {
|
if err := transferOutByteContext(context.Background(), client, []byte("byte payload"), remotePath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil {
|
||||||
t.Fatalf("transfer out bytes: %v", err)
|
t.Fatalf("transfer out bytes: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -222,7 +728,7 @@ func TestTransferInContextVerifyFailurePreservesLocalTarget(t *testing.T) {
|
|||||||
sftpVerifyLocalSizeFunc = oldVerifyLocalSize
|
sftpVerifyLocalSizeFunc = oldVerifyLocalSize
|
||||||
})
|
})
|
||||||
|
|
||||||
err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil))
|
err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil))
|
||||||
if !errors.Is(err, verifyErr) {
|
if !errors.Is(err, verifyErr) {
|
||||||
t.Fatalf("expected verify failure, got %v", err)
|
t.Fatalf("expected verify failure, got %v", err)
|
||||||
}
|
}
|
||||||
@@ -257,7 +763,7 @@ func TestTransferInContextRejectsLocalSymlinkTarget(t *testing.T) {
|
|||||||
t.Skipf("symlink unsupported: %v", err)
|
t.Skipf("symlink unsupported: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil))
|
err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil))
|
||||||
if err == nil || !strings.Contains(err.Error(), "symlink") {
|
if err == nil || !strings.Contains(err.Error(), "symlink") {
|
||||||
t.Fatalf("expected symlink rejection, got %v", err)
|
t.Fatalf("expected symlink rejection, got %v", err)
|
||||||
}
|
}
|
||||||
@@ -286,7 +792,7 @@ func TestTransferInContextRejectsLocalDirectoryTarget(t *testing.T) {
|
|||||||
t.Fatalf("mkdir local target: %v", err)
|
t.Fatalf("mkdir local target: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil))
|
err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil))
|
||||||
if err == nil || !strings.Contains(err.Error(), "directory") {
|
if err == nil || !strings.Contains(err.Error(), "directory") {
|
||||||
t.Fatalf("expected directory rejection, got %v", err)
|
t.Fatalf("expected directory rejection, got %v", err)
|
||||||
}
|
}
|
||||||
@@ -317,7 +823,7 @@ func TestTransferInContextPreservesLocalModeOnOverwrite(t *testing.T) {
|
|||||||
t.Fatalf("chmod local file: %v", err)
|
t.Fatalf("chmod local file: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)); err != nil {
|
if err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil {
|
||||||
t.Fatalf("transfer in: %v", err)
|
t.Fatalf("transfer in: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -338,7 +844,7 @@ func TestTransferInContextAppliesRemoteModeForNewLocalFile(t *testing.T) {
|
|||||||
t.Fatalf("chmod remote file: %v", err)
|
t.Fatalf("chmod remote file: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)); err != nil {
|
if err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil {
|
||||||
t.Fatalf("transfer in: %v", err)
|
t.Fatalf("transfer in: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -387,7 +893,7 @@ func TestTransferInContextCopyFailurePreservesLocalTarget(t *testing.T) {
|
|||||||
sftpCopyWithProgressFunc = oldCopy
|
sftpCopyWithProgressFunc = oldCopy
|
||||||
})
|
})
|
||||||
|
|
||||||
err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil))
|
err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil))
|
||||||
if !errors.Is(err, copyErr) {
|
if !errors.Is(err, copyErr) {
|
||||||
t.Fatalf("expected copy failure, got %v", err)
|
t.Fatalf("expected copy failure, got %v", err)
|
||||||
}
|
}
|
||||||
@@ -405,7 +911,153 @@ func TestTransferInContextCopyFailurePreservesLocalTarget(t *testing.T) {
|
|||||||
assertNoTransferTemps(t, dstPath)
|
assertNoTransferTemps(t, dstPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mustNormalizeSFTPTransferOptions(t *testing.T, options *SFTPTransferOptions) resolvedSFTPTransferOptions {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
opts, err := normalizeSFTPTransferOptions(options)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalize sftp transfer options: %v", err)
|
||||||
|
}
|
||||||
|
return opts
|
||||||
|
}
|
||||||
|
|
||||||
|
func SFTPOptionsPtr(options SFTPTransferOptions) *SFTPTransferOptions {
|
||||||
|
return &options
|
||||||
|
}
|
||||||
|
|
||||||
|
type spyConcurrentReadFrom struct {
|
||||||
|
buf bytes.Buffer
|
||||||
|
concurrency int
|
||||||
|
usedReadFrom bool
|
||||||
|
err error
|
||||||
|
errAfterRead error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *spyConcurrentReadFrom) Write(p []byte) (int, error) {
|
||||||
|
return w.buf.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *spyConcurrentReadFrom) ReadFromWithConcurrency(r io.Reader, concurrency int) (int64, error) {
|
||||||
|
w.usedReadFrom = true
|
||||||
|
w.concurrency = concurrency
|
||||||
|
if w.err != nil {
|
||||||
|
return 0, w.err
|
||||||
|
}
|
||||||
|
written, err := w.buf.ReadFrom(r)
|
||||||
|
if err != nil {
|
||||||
|
return written, err
|
||||||
|
}
|
||||||
|
if w.errAfterRead != nil {
|
||||||
|
return written, w.errAfterRead
|
||||||
|
}
|
||||||
|
return written, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type spyConcurrentWriteTo struct {
|
||||||
|
payload []byte
|
||||||
|
chunkSize int
|
||||||
|
usedWriteTo bool
|
||||||
|
err error
|
||||||
|
errAfterWrite error
|
||||||
|
cancelAfterWrite context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *spyConcurrentWriteTo) Read(p []byte) (int, error) {
|
||||||
|
if len(r.payload) == 0 {
|
||||||
|
if r.err != nil {
|
||||||
|
return 0, r.err
|
||||||
|
}
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
n := copy(p, r.payload)
|
||||||
|
r.payload = r.payload[n:]
|
||||||
|
if len(r.payload) == 0 {
|
||||||
|
return n, io.EOF
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *spyConcurrentWriteTo) WriteTo(w io.Writer) (int64, error) {
|
||||||
|
r.usedWriteTo = true
|
||||||
|
if r.err != nil {
|
||||||
|
return 0, r.err
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := r.payload
|
||||||
|
if len(payload) == 0 {
|
||||||
|
if r.errAfterWrite != nil {
|
||||||
|
return 0, r.errAfterWrite
|
||||||
|
}
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
chunkSize := r.chunkSize
|
||||||
|
if chunkSize <= 0 || chunkSize > len(payload) {
|
||||||
|
chunkSize = len(payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
var written int64
|
||||||
|
for len(payload) > 0 {
|
||||||
|
size := chunkSize
|
||||||
|
if size > len(payload) {
|
||||||
|
size = len(payload)
|
||||||
|
}
|
||||||
|
n, err := w.Write(payload[:size])
|
||||||
|
written += int64(n)
|
||||||
|
if r.cancelAfterWrite != nil && n > 0 {
|
||||||
|
r.cancelAfterWrite()
|
||||||
|
r.cancelAfterWrite = nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return written, err
|
||||||
|
}
|
||||||
|
if n != size {
|
||||||
|
return written, io.ErrShortWrite
|
||||||
|
}
|
||||||
|
payload = payload[size:]
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.errAfterWrite != nil {
|
||||||
|
return written, r.errAfterWrite
|
||||||
|
}
|
||||||
|
return written, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type chunkedReader struct {
|
||||||
|
reader io.Reader
|
||||||
|
chunkSize int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *chunkedReader) Read(p []byte) (int, error) {
|
||||||
|
if r.chunkSize > 0 && len(p) > r.chunkSize {
|
||||||
|
p = p[:r.chunkSize]
|
||||||
|
}
|
||||||
|
return r.reader.Read(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
type cancelAfterReadReader struct {
|
||||||
|
reader io.Reader
|
||||||
|
chunkSize int
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *cancelAfterReadReader) Read(p []byte) (int, error) {
|
||||||
|
if r.chunkSize > 0 && len(p) > r.chunkSize {
|
||||||
|
p = p[:r.chunkSize]
|
||||||
|
}
|
||||||
|
n, err := r.reader.Read(p)
|
||||||
|
if n > 0 && r.cancel != nil {
|
||||||
|
r.cancel()
|
||||||
|
r.cancel = nil
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
func newSFTPTestClient(t *testing.T) *sftp.Client {
|
func newSFTPTestClient(t *testing.T) *sftp.Client {
|
||||||
|
return newSFTPTestClientWithOptions(t, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSFTPTestClientWithOptions(t *testing.T, options []sftp.ClientOption) *sftp.Client {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
serverConn, clientConn := net.Pipe()
|
serverConn, clientConn := net.Pipe()
|
||||||
@@ -419,7 +1071,7 @@ func newSFTPTestClient(t *testing.T) *sftp.Client {
|
|||||||
serveErrCh <- server.Serve()
|
serveErrCh <- server.Serve()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
client, err := sftp.NewClientPipe(clientConn, clientConn)
|
client, err := sftp.NewClientPipe(clientConn, clientConn, options...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = server.Close()
|
_ = server.Close()
|
||||||
t.Fatalf("create sftp client: %v", err)
|
t.Fatalf("create sftp client: %v", err)
|
||||||
|
|||||||
@@ -0,0 +1,668 @@
|
|||||||
|
package starssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
sshagent "golang.org/x/crypto/ssh/agent"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errSSHAgentUnavailable = errors.New("ssh-agent unavailable")
|
||||||
|
var errRetrySSHAgentAuth = errors.New("retry ssh-agent auth")
|
||||||
|
var buildSSHAgentAuthMethodFunc = buildSSHAgentAuthMethod
|
||||||
|
|
||||||
|
type sshAgentTimeouts struct {
|
||||||
|
Dial time.Duration
|
||||||
|
Operation time.Duration
|
||||||
|
Forward time.Duration
|
||||||
|
Endpoint string
|
||||||
|
Resolved resolvedSSHAgentEndpoint
|
||||||
|
Debug SSHAgentDebugFunc
|
||||||
|
SkipFingerprints map[string]struct{}
|
||||||
|
SignFailure func(ssh.PublicKey, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type sshAgentAuthAttempt struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
skipFingerprints map[string]struct{}
|
||||||
|
retryRequested bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultAuthOrder = []AuthMethodKind{
|
||||||
|
AuthMethodSSHAgent,
|
||||||
|
AuthMethodPrivateKey,
|
||||||
|
AuthMethodPassword,
|
||||||
|
AuthMethodKeyboardInteractive,
|
||||||
|
}
|
||||||
|
|
||||||
|
func effectiveSSHAgentTimeout(info LoginInput) time.Duration {
|
||||||
|
switch {
|
||||||
|
case info.SSHAgentTimeout < 0:
|
||||||
|
return 0
|
||||||
|
case info.SSHAgentTimeout > 0:
|
||||||
|
return info.SSHAgentTimeout
|
||||||
|
default:
|
||||||
|
return defaultSSHAgentTimeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func effectiveSSHAgentTimeouts(info LoginInput) sshAgentTimeouts {
|
||||||
|
return sshAgentTimeouts{
|
||||||
|
Dial: effectiveDialTimeout(info),
|
||||||
|
Operation: effectiveSSHAgentTimeout(info),
|
||||||
|
Forward: effectiveSSHAgentForwardTimeout(info),
|
||||||
|
Endpoint: info.IdentityAgent,
|
||||||
|
Debug: info.SSHAgentDebug,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func effectiveSSHAgentForwardTimeout(info LoginInput) time.Duration {
|
||||||
|
if info.SSHAgentForwardTimeout > 0 {
|
||||||
|
return info.SSHAgentForwardTimeout
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) {
|
||||||
|
return buildAuthMethodsWithAgentAttempt(info, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildAuthMethodsWithAgentAttempt(info LoginInput, agentAttempt *sshAgentAuthAttempt) ([]ssh.AuthMethod, func(), error) {
|
||||||
|
order, err := normalizeAuthOrder(info.AuthOrder)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
auth := make([]ssh.AuthMethod, 0, len(order))
|
||||||
|
var agentErr error
|
||||||
|
var cleanupFuncs []func()
|
||||||
|
|
||||||
|
for _, methodKind := range order {
|
||||||
|
switch methodKind {
|
||||||
|
case AuthMethodPrivateKey:
|
||||||
|
method, err := buildPrivateKeyAuthMethod(info, agentAttempt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
if method != nil {
|
||||||
|
auth = append(auth, method)
|
||||||
|
}
|
||||||
|
case AuthMethodPassword:
|
||||||
|
method := buildPasswordAuthMethod(info.Password, info.PasswordCallback, agentAttempt)
|
||||||
|
if method != nil {
|
||||||
|
auth = append(auth, method)
|
||||||
|
}
|
||||||
|
case AuthMethodKeyboardInteractive:
|
||||||
|
method := buildKeyboardInteractiveAuthMethod(info.Password, info.PasswordCallback, info.KeyboardInteractiveCallback, agentAttempt)
|
||||||
|
if method != nil {
|
||||||
|
auth = append(auth, method)
|
||||||
|
}
|
||||||
|
case AuthMethodSSHAgent:
|
||||||
|
if info.DisableSSHAgent {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
timeouts := effectiveSSHAgentTimeouts(info)
|
||||||
|
if agentAttempt != nil {
|
||||||
|
timeouts.SkipFingerprints = agentAttempt.skipSnapshot()
|
||||||
|
timeouts.SignFailure = agentAttempt.recordSignFailure
|
||||||
|
}
|
||||||
|
agentMethod, cleanup, err := buildSSHAgentAuthMethodFunc(timeouts)
|
||||||
|
if err != nil {
|
||||||
|
agentErr = err
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if agentMethod != nil {
|
||||||
|
auth = append(auth, agentMethod)
|
||||||
|
}
|
||||||
|
if cleanup != nil {
|
||||||
|
cleanupFuncs = append(cleanupFuncs, cleanup)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(auth) == 0 {
|
||||||
|
if agentErr != nil {
|
||||||
|
return nil, nil, fmt.Errorf("no authentication method provided; ssh-agent unavailable: %w", agentErr)
|
||||||
|
}
|
||||||
|
return nil, nil, errors.New("no authentication method provided: password, private key, or ssh-agent is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
return auth, composeCleanup(cleanupFuncs...), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeAuthOrder(order []AuthMethodKind) ([]AuthMethodKind, error) {
|
||||||
|
if len(order) == 0 {
|
||||||
|
return append([]AuthMethodKind(nil), defaultAuthOrder...), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized := make([]AuthMethodKind, 0, len(order))
|
||||||
|
seen := make(map[AuthMethodKind]struct{}, len(order))
|
||||||
|
for _, raw := range order {
|
||||||
|
kind := AuthMethodKind(strings.ToLower(strings.TrimSpace(string(raw))))
|
||||||
|
if kind == "" {
|
||||||
|
return nil, errors.New("auth order contains an empty auth method")
|
||||||
|
}
|
||||||
|
if !isSupportedAuthMethodKind(kind) {
|
||||||
|
return nil, fmt.Errorf("unsupported auth method %q", raw)
|
||||||
|
}
|
||||||
|
if _, exists := seen[kind]; exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[kind] = struct{}{}
|
||||||
|
normalized = append(normalized, kind)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(normalized) == 0 {
|
||||||
|
return nil, errors.New("auth order is empty")
|
||||||
|
}
|
||||||
|
return normalized, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSupportedAuthMethodKind(kind AuthMethodKind) bool {
|
||||||
|
switch kind {
|
||||||
|
case AuthMethodPrivateKey, AuthMethodPassword, AuthMethodKeyboardInteractive, AuthMethodSSHAgent:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldRetrySSHAgentAuth(info LoginInput, order []AuthMethodKind) bool {
|
||||||
|
if info.DisableSSHAgent {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, methodKind := range order {
|
||||||
|
if methodKind == AuthMethodSSHAgent {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildPrivateKeyAuthMethod(info LoginInput, agentAttempt *sshAgentAuthAttempt) (ssh.AuthMethod, error) {
|
||||||
|
if strings.TrimSpace(info.Prikey) == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
pemBytes := []byte(info.Prikey)
|
||||||
|
if info.PrikeyPwd == "" {
|
||||||
|
signer, err := ssh.ParsePrivateKey(pemBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ssh.PublicKeysCallback(privateKeySignersCallback(signer, agentAttempt)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
signer, err := ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(info.PrikeyPwd))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ssh.PublicKeysCallback(privateKeySignersCallback(signer, agentAttempt)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func privateKeySignersCallback(signer ssh.Signer, agentAttempt *sshAgentAuthAttempt) func() ([]ssh.Signer, error) {
|
||||||
|
return func() ([]ssh.Signer, error) {
|
||||||
|
if err := checkSSHAgentRetryPending(agentAttempt); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return []ssh.Signer{signer}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildPasswordAuthMethod(password string, callback func() (string, error), agentAttempt *sshAgentAuthAttempt) ssh.AuthMethod {
|
||||||
|
if password == "" && callback == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return ssh.PasswordCallback(func() (string, error) {
|
||||||
|
if err := checkSSHAgentRetryPending(agentAttempt); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if password != "" {
|
||||||
|
return password, nil
|
||||||
|
}
|
||||||
|
return callback()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildKeyboardInteractiveAuthMethod(
|
||||||
|
password string,
|
||||||
|
passwordCallback func() (string, error),
|
||||||
|
challenge ssh.KeyboardInteractiveChallenge,
|
||||||
|
agentAttempt *sshAgentAuthAttempt,
|
||||||
|
) ssh.AuthMethod {
|
||||||
|
if challenge != nil {
|
||||||
|
return ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) ([]string, error) {
|
||||||
|
if err := checkSSHAgentRetryPending(agentAttempt); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return challenge(user, instruction, questions, echos)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if password == "" && passwordCallback == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
keyboardInteractiveChallenge := func(user, instruction string, questions []string, echos []bool) ([]string, error) {
|
||||||
|
if err := checkSSHAgentRetryPending(agentAttempt); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(questions) == 0 {
|
||||||
|
return []string{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
answer := password
|
||||||
|
if answer == "" {
|
||||||
|
var err error
|
||||||
|
answer, err = passwordCallback()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
answers := make([]string, len(questions))
|
||||||
|
for i := range questions {
|
||||||
|
answers[i] = answer
|
||||||
|
}
|
||||||
|
return answers, nil
|
||||||
|
}
|
||||||
|
return ssh.KeyboardInteractive(keyboardInteractiveChallenge)
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildSSHAgentAuthMethod(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
|
||||||
|
conn, resolved, err := dialSSHAgentWithDebug("auth", timeouts)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, errSSHAgentUnavailable) {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
if conn == nil {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
conn = wrapSSHAgentConnWithDeadline(conn, timeouts.Operation)
|
||||||
|
|
||||||
|
started := time.Now()
|
||||||
|
signers, err := sshagent.NewClient(conn).Signers()
|
||||||
|
err = normalizeSSHAgentError(err)
|
||||||
|
logSSHAgentDebug(timeouts.Debug, SSHAgentDebugEvent{
|
||||||
|
Step: "auth",
|
||||||
|
Source: resolved.Source,
|
||||||
|
Endpoint: resolved.Endpoint,
|
||||||
|
Network: resolved.Network,
|
||||||
|
Phase: "list",
|
||||||
|
Status: debugStatus(err),
|
||||||
|
Duration: time.Since(started),
|
||||||
|
KeyCount: len(signers),
|
||||||
|
Err: err,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
if len(signers) == 0 {
|
||||||
|
_ = conn.Close()
|
||||||
|
return nil, nil, errors.New("ssh-agent has no loaded keys")
|
||||||
|
}
|
||||||
|
|
||||||
|
timeouts.Resolved = resolved
|
||||||
|
orderedSigners := orderSSHAgentSigners(signers)
|
||||||
|
filteredSigners := filterSSHAgentSignersForRetry(orderedSigners, timeouts)
|
||||||
|
if len(filteredSigners) == 0 {
|
||||||
|
_ = conn.Close()
|
||||||
|
return nil, nil, errors.New("ssh-agent has no usable keys")
|
||||||
|
}
|
||||||
|
|
||||||
|
return ssh.PublicKeys(filteredSigners...), func() {
|
||||||
|
_ = conn.Close()
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func orderSSHAgentSigners(signers []ssh.Signer) []ssh.Signer {
|
||||||
|
type orderedSigner struct {
|
||||||
|
signer ssh.Signer
|
||||||
|
index int
|
||||||
|
score int
|
||||||
|
comment string
|
||||||
|
}
|
||||||
|
|
||||||
|
ordered := make([]orderedSigner, 0, len(signers))
|
||||||
|
for index, signer := range signers {
|
||||||
|
if signer == nil || signer.PublicKey() == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ordered = append(ordered, orderedSigner{
|
||||||
|
signer: signer,
|
||||||
|
index: index,
|
||||||
|
score: sshAgentSignerPriority(signer),
|
||||||
|
comment: sshAgentSignerComment(signer),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.SliceStable(ordered, func(i, j int) bool {
|
||||||
|
if ordered[i].score != ordered[j].score {
|
||||||
|
return ordered[i].score > ordered[j].score
|
||||||
|
}
|
||||||
|
return ordered[i].index < ordered[j].index
|
||||||
|
})
|
||||||
|
|
||||||
|
result := make([]ssh.Signer, 0, len(ordered))
|
||||||
|
for _, item := range ordered {
|
||||||
|
result = append(result, item.signer)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func sshAgentSignerComment(signer ssh.Signer) string {
|
||||||
|
if signer == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if key, ok := signer.PublicKey().(*sshagent.Key); ok {
|
||||||
|
return key.Comment
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func sshAgentSignerPriority(signer ssh.Signer) int {
|
||||||
|
comment := strings.TrimSpace(sshAgentSignerComment(signer))
|
||||||
|
if comment == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
score := 0
|
||||||
|
if priority, ok := parseSSHAgentSignerPriority(comment); ok {
|
||||||
|
score += 100000 + priority*1000
|
||||||
|
}
|
||||||
|
|
||||||
|
lower := strings.ToLower(comment)
|
||||||
|
if strings.Contains(lower, "current") {
|
||||||
|
score += 400
|
||||||
|
}
|
||||||
|
if strings.Contains(lower, "cardno:") {
|
||||||
|
score += 300
|
||||||
|
}
|
||||||
|
if strings.Contains(lower, "card ") || strings.Contains(lower, " card") || strings.Contains(lower, "card:") {
|
||||||
|
score += 100
|
||||||
|
}
|
||||||
|
if strings.Contains(lower, "openpgp") || strings.Contains(lower, "gpg") {
|
||||||
|
score += 50
|
||||||
|
}
|
||||||
|
return score
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseSSHAgentSignerPriority(comment string) (int, bool) {
|
||||||
|
lower := strings.ToLower(comment)
|
||||||
|
index := strings.Index(lower, "priority=")
|
||||||
|
if index < 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
value := strings.TrimSpace(comment[index+len("priority="):])
|
||||||
|
if value == "" {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
end := 0
|
||||||
|
for end < len(value) {
|
||||||
|
ch := value[end]
|
||||||
|
if ch == '+' || ch == '-' || (ch >= '0' && ch <= '9') {
|
||||||
|
end++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if end == 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
priority, err := strconv.Atoi(value[:end])
|
||||||
|
if err != nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return priority, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func filterSSHAgentSignersForRetry(signers []ssh.Signer, timeouts sshAgentTimeouts) []ssh.Signer {
|
||||||
|
filteredSigners := make([]ssh.Signer, 0, len(signers))
|
||||||
|
for _, signer := range signers {
|
||||||
|
if signer == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
publicKey := signer.PublicKey()
|
||||||
|
if publicKey == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, skip := timeouts.SkipFingerprints[ssh.FingerprintSHA256(publicKey)]; skip {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if timeouts.SignFailure == nil && timeouts.Debug == nil {
|
||||||
|
filteredSigners = append(filteredSigners, signer)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filteredSigners = append(filteredSigners, wrapSSHAgentSigner(signer, sshAgentSignerOptions{
|
||||||
|
Resolved: timeouts.Resolved,
|
||||||
|
Debug: timeouts.Debug,
|
||||||
|
SignFailure: timeouts.SignFailure,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
return filteredSigners
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSSHAgentAuthAttempt() *sshAgentAuthAttempt {
|
||||||
|
return &sshAgentAuthAttempt{
|
||||||
|
skipFingerprints: make(map[string]struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *sshAgentAuthAttempt) begin() {
|
||||||
|
if a == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
a.mu.Lock()
|
||||||
|
defer a.mu.Unlock()
|
||||||
|
a.retryRequested = false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *sshAgentAuthAttempt) skipSnapshot() map[string]struct{} {
|
||||||
|
if a == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
a.mu.Lock()
|
||||||
|
defer a.mu.Unlock()
|
||||||
|
if len(a.skipFingerprints) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
snapshot := make(map[string]struct{}, len(a.skipFingerprints))
|
||||||
|
for fingerprint := range a.skipFingerprints {
|
||||||
|
snapshot[fingerprint] = struct{}{}
|
||||||
|
}
|
||||||
|
return snapshot
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *sshAgentAuthAttempt) recordSignFailure(publicKey ssh.PublicKey, err error) {
|
||||||
|
_ = err
|
||||||
|
if a == nil || publicKey == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
a.skipFingerprint(ssh.FingerprintSHA256(publicKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *sshAgentAuthAttempt) skipFingerprint(fingerprint string) {
|
||||||
|
if a == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
a.mu.Lock()
|
||||||
|
defer a.mu.Unlock()
|
||||||
|
a.retryRequested = true
|
||||||
|
if fingerprint != "" {
|
||||||
|
a.skipFingerprints[fingerprint] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *sshAgentAuthAttempt) shouldRetry() bool {
|
||||||
|
if a == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
a.mu.Lock()
|
||||||
|
defer a.mu.Unlock()
|
||||||
|
return a.retryRequested
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkSSHAgentRetryPending(agentAttempt *sshAgentAuthAttempt) error {
|
||||||
|
if agentAttempt != nil && agentAttempt.shouldRetry() {
|
||||||
|
return errRetrySSHAgentAuth
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type sshAgentRetrySigner struct {
|
||||||
|
signer ssh.Signer
|
||||||
|
publicKey ssh.PublicKey
|
||||||
|
options sshAgentSignerOptions
|
||||||
|
}
|
||||||
|
|
||||||
|
type sshAgentRetryAlgorithmSigner struct {
|
||||||
|
sshAgentRetrySigner
|
||||||
|
algorithmSigner ssh.AlgorithmSigner
|
||||||
|
}
|
||||||
|
|
||||||
|
type sshAgentRetryMultiAlgorithmSigner struct {
|
||||||
|
sshAgentRetryAlgorithmSigner
|
||||||
|
multiAlgorithmSigner ssh.MultiAlgorithmSigner
|
||||||
|
}
|
||||||
|
|
||||||
|
type sshAgentSignerOptions struct {
|
||||||
|
Resolved resolvedSSHAgentEndpoint
|
||||||
|
Debug SSHAgentDebugFunc
|
||||||
|
SignFailure func(ssh.PublicKey, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func wrapSSHAgentSignerForRetry(signer ssh.Signer, onFailure func(ssh.PublicKey, error)) ssh.Signer {
|
||||||
|
return wrapSSHAgentSigner(signer, sshAgentSignerOptions{SignFailure: onFailure})
|
||||||
|
}
|
||||||
|
|
||||||
|
func wrapSSHAgentSigner(signer ssh.Signer, options sshAgentSignerOptions) ssh.Signer {
|
||||||
|
publicKey := signer.PublicKey()
|
||||||
|
base := sshAgentRetrySigner{
|
||||||
|
signer: signer,
|
||||||
|
publicKey: publicKey,
|
||||||
|
options: options,
|
||||||
|
}
|
||||||
|
if multiAlgorithmSigner, ok := signer.(ssh.MultiAlgorithmSigner); ok {
|
||||||
|
return &sshAgentRetryMultiAlgorithmSigner{
|
||||||
|
sshAgentRetryAlgorithmSigner: sshAgentRetryAlgorithmSigner{
|
||||||
|
sshAgentRetrySigner: base,
|
||||||
|
algorithmSigner: multiAlgorithmSigner,
|
||||||
|
},
|
||||||
|
multiAlgorithmSigner: multiAlgorithmSigner,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if algorithmSigner, ok := signer.(ssh.AlgorithmSigner); ok {
|
||||||
|
return &sshAgentRetryAlgorithmSigner{
|
||||||
|
sshAgentRetrySigner: base,
|
||||||
|
algorithmSigner: algorithmSigner,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &base
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *sshAgentRetrySigner) PublicKey() ssh.PublicKey {
|
||||||
|
return s.publicKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *sshAgentRetrySigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) {
|
||||||
|
started := time.Now()
|
||||||
|
signature, err := s.signer.Sign(rand, data)
|
||||||
|
return signature, s.finishSign(started, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *sshAgentRetrySigner) finishSign(started time.Time, err error) error {
|
||||||
|
err = normalizeSSHAgentError(err)
|
||||||
|
s.logSignDebug(started, err)
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if s.options.SignFailure != nil {
|
||||||
|
s.options.SignFailure(s.publicKey, err)
|
||||||
|
return wrapSSHAgentSignError(err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *sshAgentRetrySigner) logSignDebug(started time.Time, err error) {
|
||||||
|
if s == nil || s.options.Debug == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logSSHAgentDebug(s.options.Debug, SSHAgentDebugEvent{
|
||||||
|
Step: "auth",
|
||||||
|
Source: s.options.Resolved.Source,
|
||||||
|
Endpoint: s.options.Resolved.Endpoint,
|
||||||
|
Network: s.options.Resolved.Network,
|
||||||
|
Phase: "sign",
|
||||||
|
Status: debugStatus(err),
|
||||||
|
Duration: time.Since(started),
|
||||||
|
Err: err,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *sshAgentRetryAlgorithmSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) {
|
||||||
|
algorithm = preferredSSHAgentSignAlgorithm(s.publicKey, algorithm, nil)
|
||||||
|
started := time.Now()
|
||||||
|
signature, err := s.algorithmSigner.SignWithAlgorithm(rand, data, algorithm)
|
||||||
|
return signature, s.finishSign(started, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *sshAgentRetryMultiAlgorithmSigner) Algorithms() []string {
|
||||||
|
return s.multiAlgorithmSigner.Algorithms()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *sshAgentRetryMultiAlgorithmSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) {
|
||||||
|
algorithm = preferredSSHAgentSignAlgorithm(s.publicKey, algorithm, s.multiAlgorithmSigner.Algorithms())
|
||||||
|
started := time.Now()
|
||||||
|
signature, err := s.multiAlgorithmSigner.SignWithAlgorithm(rand, data, algorithm)
|
||||||
|
return signature, s.finishSign(started, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func preferredSSHAgentSignAlgorithm(publicKey ssh.PublicKey, requested string, algorithms []string) string {
|
||||||
|
if publicKey == nil || publicKey.Type() != ssh.KeyAlgoRSA || requested != ssh.KeyAlgoRSA {
|
||||||
|
return requested
|
||||||
|
}
|
||||||
|
if len(algorithms) == 0 {
|
||||||
|
return ssh.KeyAlgoRSASHA256
|
||||||
|
}
|
||||||
|
for _, algorithm := range algorithms {
|
||||||
|
if algorithm == ssh.KeyAlgoRSA {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if algorithm == ssh.KeyAlgoRSASHA256 || algorithm == ssh.KeyAlgoRSASHA512 {
|
||||||
|
return algorithm
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return requested
|
||||||
|
}
|
||||||
|
|
||||||
|
func wrapSSHAgentSignError(err error) error {
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("%w: %v", errRetrySSHAgentAuth, normalizeSSHAgentError(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
func composeCleanup(funcs ...func()) func() {
|
||||||
|
if len(funcs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return func() {
|
||||||
|
for i := len(funcs) - 1; i >= 0; i-- {
|
||||||
|
if funcs[i] != nil {
|
||||||
|
funcs[i]()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,158 @@
|
|||||||
|
package starssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrSSHAgentTimeout = errors.New("ssh-agent timeout")
|
||||||
|
var dialResolvedSSHAgentFunc = dialResolvedSSHAgent
|
||||||
|
|
||||||
|
type sshAgentDialOptions struct {
|
||||||
|
Endpoint string
|
||||||
|
Timeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
type resolvedSSHAgentEndpoint struct {
|
||||||
|
Endpoint string
|
||||||
|
Source string
|
||||||
|
Network string
|
||||||
|
}
|
||||||
|
|
||||||
|
type deadlineAgentConn struct {
|
||||||
|
net.Conn
|
||||||
|
timeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveSSHAgentEndpoint(options sshAgentDialOptions) (resolvedSSHAgentEndpoint, error) {
|
||||||
|
endpoint := strings.TrimSpace(options.Endpoint)
|
||||||
|
if endpoint != "" {
|
||||||
|
return resolvedSSHAgentEndpoint{
|
||||||
|
Endpoint: endpoint,
|
||||||
|
Source: "identity-agent",
|
||||||
|
Network: defaultSSHAgentNetwork(endpoint),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint = strings.TrimSpace(os.Getenv("SSH_AUTH_SOCK"))
|
||||||
|
if endpoint != "" {
|
||||||
|
return resolvedSSHAgentEndpoint{
|
||||||
|
Endpoint: endpoint,
|
||||||
|
Source: "SSH_AUTH_SOCK",
|
||||||
|
Network: defaultSSHAgentNetwork(endpoint),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return defaultSSHAgentEndpoint()
|
||||||
|
}
|
||||||
|
|
||||||
|
func dialSSHAgent(options sshAgentDialOptions) (net.Conn, resolvedSSHAgentEndpoint, error) {
|
||||||
|
resolved, err := resolveSSHAgentEndpoint(options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, resolvedSSHAgentEndpoint{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := dialResolvedSSHAgentFunc(resolved, options.Timeout)
|
||||||
|
if isTimeoutError(err) {
|
||||||
|
err = fmt.Errorf("%w: %v", ErrSSHAgentTimeout, err)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, resolved, err
|
||||||
|
}
|
||||||
|
return conn, resolved, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func dialSSHAgentWithDebug(step string, timeouts sshAgentTimeouts) (net.Conn, resolvedSSHAgentEndpoint, error) {
|
||||||
|
options := sshAgentDialOptions{
|
||||||
|
Endpoint: timeouts.Endpoint,
|
||||||
|
Timeout: timeouts.Dial,
|
||||||
|
}
|
||||||
|
started := time.Now()
|
||||||
|
conn, resolved, err := dialSSHAgent(options)
|
||||||
|
logSSHAgentDebug(timeouts.Debug, SSHAgentDebugEvent{
|
||||||
|
Step: step,
|
||||||
|
Source: resolved.Source,
|
||||||
|
Endpoint: resolved.Endpoint,
|
||||||
|
Network: resolved.Network,
|
||||||
|
Phase: "dial",
|
||||||
|
Status: debugStatus(err),
|
||||||
|
Duration: time.Since(started),
|
||||||
|
Err: err,
|
||||||
|
})
|
||||||
|
return conn, resolved, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func logSSHAgentDebug(debug SSHAgentDebugFunc, event SSHAgentDebugEvent) {
|
||||||
|
if debug == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
debug(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
func debugStatus(err error) string {
|
||||||
|
if err != nil {
|
||||||
|
return "error"
|
||||||
|
}
|
||||||
|
return "ok"
|
||||||
|
}
|
||||||
|
|
||||||
|
func wrapSSHAgentConnWithDeadline(conn net.Conn, timeout time.Duration) net.Conn {
|
||||||
|
if conn == nil || timeout <= 0 {
|
||||||
|
return conn
|
||||||
|
}
|
||||||
|
return &deadlineAgentConn{Conn: conn, timeout: timeout}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *deadlineAgentConn) Read(p []byte) (int, error) {
|
||||||
|
c.setDeadline()
|
||||||
|
n, err := c.Conn.Read(p)
|
||||||
|
return n, wrapSSHAgentConnError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *deadlineAgentConn) Write(p []byte) (int, error) {
|
||||||
|
c.setDeadline()
|
||||||
|
n, err := c.Conn.Write(p)
|
||||||
|
return n, wrapSSHAgentConnError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *deadlineAgentConn) setDeadline() {
|
||||||
|
if c == nil || c.timeout <= 0 || c.Conn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = c.Conn.SetDeadline(time.Now().Add(c.timeout))
|
||||||
|
}
|
||||||
|
|
||||||
|
func isTimeoutError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
var netErr net.Error
|
||||||
|
return errors.As(err, &netErr) && netErr.Timeout()
|
||||||
|
}
|
||||||
|
|
||||||
|
func wrapSSHAgentConnError(err error) error {
|
||||||
|
if isTimeoutError(err) {
|
||||||
|
return fmt.Errorf("%w: %v", ErrSSHAgentTimeout, err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeSSHAgentError(err error) error {
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if errors.Is(err, ErrSSHAgentTimeout) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if strings.Contains(err.Error(), ErrSSHAgentTimeout.Error()) {
|
||||||
|
return fmt.Errorf("%w: %v", ErrSSHAgentTimeout, err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
+10
-7
@@ -4,16 +4,19 @@ package starssh
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func dialSSHAgent(timeout time.Duration) (net.Conn, error) {
|
func defaultSSHAgentEndpoint() (resolvedSSHAgentEndpoint, error) {
|
||||||
agentSock := strings.TrimSpace(os.Getenv("SSH_AUTH_SOCK"))
|
return resolvedSSHAgentEndpoint{}, errSSHAgentUnavailable
|
||||||
if agentSock == "" {
|
}
|
||||||
return nil, errSSHAgentUnavailable
|
|
||||||
}
|
func defaultSSHAgentNetwork(endpoint string) string {
|
||||||
|
return "unix"
|
||||||
|
}
|
||||||
|
|
||||||
|
func dialResolvedSSHAgent(resolved resolvedSSHAgentEndpoint, timeout time.Duration) (net.Conn, error) {
|
||||||
|
agentSock := resolved.Endpoint
|
||||||
if timeout > 0 {
|
if timeout > 0 {
|
||||||
return net.DialTimeout("unix", agentSock, timeout)
|
return net.DialTimeout("unix", agentSock, timeout)
|
||||||
}
|
}
|
||||||
|
|||||||
+218
-17
@@ -3,10 +3,16 @@
|
|||||||
package starssh
|
package starssh
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -16,22 +22,40 @@ import (
|
|||||||
|
|
||||||
const defaultWindowsSSHAgentPipe = `\\.\pipe\openssh-ssh-agent`
|
const defaultWindowsSSHAgentPipe = `\\.\pipe\openssh-ssh-agent`
|
||||||
|
|
||||||
func dialSSHAgent(timeout time.Duration) (net.Conn, error) {
|
var errInvalidGPGSocketInfo = errors.New("invalid gpg agent socket file")
|
||||||
agentSock := strings.TrimSpace(os.Getenv("SSH_AUTH_SOCK"))
|
|
||||||
if agentSock != "" {
|
type gpgSocketInfo struct {
|
||||||
return dialWindowsSSHAgentEndpoint(agentSock, timeout)
|
port uint16
|
||||||
}
|
nonce []byte
|
||||||
return dialWindowsNamedPipe(defaultWindowsSSHAgentPipe, timeout, true)
|
cygwin bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func dialWindowsSSHAgentEndpoint(endpoint string, timeout time.Duration) (net.Conn, error) {
|
func defaultSSHAgentEndpoint() (resolvedSSHAgentEndpoint, error) {
|
||||||
if pipePath, ok := normalizeWindowsSSHAgentPipe(endpoint); ok {
|
return resolvedSSHAgentEndpoint{
|
||||||
return dialWindowsNamedPipe(pipePath, timeout, false)
|
Endpoint: defaultWindowsSSHAgentPipe,
|
||||||
|
Source: "platform-default",
|
||||||
|
Network: "windows-pipe",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultSSHAgentNetwork(endpoint string) string {
|
||||||
|
if _, ok := normalizeWindowsSSHAgentPipe(endpoint); ok {
|
||||||
|
return "windows-pipe"
|
||||||
}
|
}
|
||||||
if timeout > 0 {
|
if isAgentSSHSocketPath(endpoint) {
|
||||||
return net.DialTimeout("unix", endpoint, timeout)
|
return "gpg-socket"
|
||||||
}
|
}
|
||||||
return net.Dial("unix", endpoint)
|
return "unix"
|
||||||
|
}
|
||||||
|
|
||||||
|
func dialResolvedSSHAgent(resolved resolvedSSHAgentEndpoint, timeout time.Duration) (net.Conn, error) {
|
||||||
|
if pipePath, ok := normalizeWindowsSSHAgentPipe(resolved.Endpoint); ok {
|
||||||
|
return dialWindowsNamedPipe(pipePath, timeout, resolved.Source == "platform-default")
|
||||||
|
}
|
||||||
|
if isAgentSSHSocketPath(resolved.Endpoint) {
|
||||||
|
return dialWindowsGPGSocketFile(resolved.Endpoint, timeout)
|
||||||
|
}
|
||||||
|
return dialWindowsUnixAgent(resolved.Endpoint, timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
func dialWindowsNamedPipe(path string, timeout time.Duration, unavailableOnNotFound bool) (net.Conn, error) {
|
func dialWindowsNamedPipe(path string, timeout time.Duration, unavailableOnNotFound bool) (net.Conn, error) {
|
||||||
@@ -42,11 +66,7 @@ func dialWindowsNamedPipe(path string, timeout time.Duration, unavailableOnNotFo
|
|||||||
}
|
}
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
conn, err := winio.DialPipeContext(ctx, path)
|
return dialWindowsNamedPipeContext(ctx, path, unavailableOnNotFound)
|
||||||
if err != nil && unavailableOnNotFound && isWindowsPipeUnavailable(err) {
|
|
||||||
return nil, errSSHAgentUnavailable
|
|
||||||
}
|
|
||||||
return conn, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func normalizeWindowsSSHAgentPipe(endpoint string) (string, bool) {
|
func normalizeWindowsSSHAgentPipe(endpoint string) (string, bool) {
|
||||||
@@ -68,3 +88,184 @@ func normalizeWindowsSSHAgentPipe(endpoint string) (string, bool) {
|
|||||||
func isWindowsPipeUnavailable(err error) bool {
|
func isWindowsPipeUnavailable(err error) bool {
|
||||||
return errors.Is(err, windows.ERROR_FILE_NOT_FOUND) || errors.Is(err, windows.ERROR_PATH_NOT_FOUND)
|
return errors.Is(err, windows.ERROR_FILE_NOT_FOUND) || errors.Is(err, windows.ERROR_PATH_NOT_FOUND)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func dialWindowsUnixAgent(endpoint string, timeout time.Duration) (net.Conn, error) {
|
||||||
|
if timeout > 0 {
|
||||||
|
return net.DialTimeout("unix", endpoint, timeout)
|
||||||
|
}
|
||||||
|
return net.Dial("unix", endpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
func dialWindowsGPGSocketFile(path string, timeout time.Duration) (net.Conn, error) {
|
||||||
|
ctx := context.Background()
|
||||||
|
cancel := func() {}
|
||||||
|
if timeout > 0 {
|
||||||
|
ctx, cancel = context.WithTimeout(ctx, timeout)
|
||||||
|
}
|
||||||
|
defer cancel()
|
||||||
|
return dialWindowsGPGSocketFileDepth(ctx, strings.TrimSpace(path), 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func dialWindowsGPGSocketFileDepth(ctx context.Context, path string, depth int) (net.Conn, error) {
|
||||||
|
if path == "" {
|
||||||
|
return nil, fmt.Errorf("gpg agent endpoint is empty")
|
||||||
|
}
|
||||||
|
if depth > 8 {
|
||||||
|
return nil, fmt.Errorf("gpg agent socket redirect loop at %s", path)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if target, ok := parseGPGAssuanSocketRedirect(data); ok {
|
||||||
|
target = resolveGPGSocketRedirectTarget(path, target)
|
||||||
|
if pipePath, ok := normalizeWindowsSSHAgentPipe(target); ok {
|
||||||
|
return dialWindowsNamedPipeContext(ctx, pipePath, false)
|
||||||
|
}
|
||||||
|
return dialWindowsGPGSocketFileDepth(ctx, target, depth+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := parseGPGSocketInfo(path, data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return dialWindowsGPGSocketInfo(ctx, info)
|
||||||
|
}
|
||||||
|
|
||||||
|
func dialWindowsGPGSocketInfo(ctx context.Context, info gpgSocketInfo) (net.Conn, error) {
|
||||||
|
var dialer net.Dialer
|
||||||
|
conn, err := dialer.DialContext(ctx, "tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(int(info.port))))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
|
if err := conn.SetDeadline(deadline); err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, err := conn.Write(info.nonce); err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if info.cygwin {
|
||||||
|
var nonce [16]byte
|
||||||
|
if _, err := io.ReadFull(conn, nonce[:]); err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var credential [8]byte
|
||||||
|
binary.LittleEndian.PutUint32(credential[:4], uint32(os.Getpid()))
|
||||||
|
if _, err := conn.Write(credential[:]); err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if _, err := io.ReadFull(conn, credential[:]); err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = conn.SetDeadline(time.Time{})
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveGPGSocketRedirectTarget(source string, target string) string {
|
||||||
|
target = strings.TrimSpace(target)
|
||||||
|
if target == "" || filepath.IsAbs(target) {
|
||||||
|
return target
|
||||||
|
}
|
||||||
|
if _, ok := normalizeWindowsSSHAgentPipe(target); ok {
|
||||||
|
return target
|
||||||
|
}
|
||||||
|
return filepath.Join(filepath.Dir(source), target)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseGPGSocketInfo(path string, data []byte) (gpgSocketInfo, error) {
|
||||||
|
if info, ok := parseGPGAssuanSocketInfo(data); ok {
|
||||||
|
return info, nil
|
||||||
|
}
|
||||||
|
if info, ok := parseGPGCygwinSocketInfo(data); ok {
|
||||||
|
return info, nil
|
||||||
|
}
|
||||||
|
return gpgSocketInfo{}, fmt.Errorf("%w %s: expected GnuPG port/nonce socket file; if SSH_AUTH_SOCK was set to this file, restart gpg-agent to recreate it", errInvalidGPGSocketInfo, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseGPGAssuanSocketRedirect(data []byte) (string, bool) {
|
||||||
|
text := strings.ReplaceAll(string(data), "\r\n", "\n")
|
||||||
|
text = strings.TrimSuffix(text, "\n")
|
||||||
|
lines := strings.Split(text, "\n")
|
||||||
|
if len(lines) != 2 || lines[0] != "%Assuan%" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
target, ok := strings.CutPrefix(lines[1], "socket=")
|
||||||
|
if !ok || strings.TrimSpace(target) == "" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return os.ExpandEnv(target), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseGPGAssuanSocketInfo(data []byte) (gpgSocketInfo, bool) {
|
||||||
|
newline := bytes.IndexByte(data, '\n')
|
||||||
|
if newline <= 0 || len(data)-newline-1 != 16 {
|
||||||
|
return gpgSocketInfo{}, false
|
||||||
|
}
|
||||||
|
port64, err := strconv.ParseUint(strings.TrimSpace(string(data[:newline])), 10, 16)
|
||||||
|
if err != nil || port64 == 0 {
|
||||||
|
return gpgSocketInfo{}, false
|
||||||
|
}
|
||||||
|
nonce := make([]byte, 16)
|
||||||
|
copy(nonce, data[newline+1:])
|
||||||
|
return gpgSocketInfo{port: uint16(port64), nonce: nonce}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseGPGCygwinSocketInfo(data []byte) (gpgSocketInfo, bool) {
|
||||||
|
if !bytes.HasPrefix(data, []byte("!<socket >")) {
|
||||||
|
return gpgSocketInfo{}, false
|
||||||
|
}
|
||||||
|
fields := strings.Fields(strings.TrimRight(string(data[10:]), "\x00"))
|
||||||
|
if len(fields) != 3 || fields[1] != "s" {
|
||||||
|
return gpgSocketInfo{}, false
|
||||||
|
}
|
||||||
|
port64, err := strconv.ParseUint(fields[0], 10, 16)
|
||||||
|
if err != nil || port64 == 0 {
|
||||||
|
return gpgSocketInfo{}, false
|
||||||
|
}
|
||||||
|
hexParts := strings.Split(fields[2], "-")
|
||||||
|
if len(hexParts) != 4 {
|
||||||
|
return gpgSocketInfo{}, false
|
||||||
|
}
|
||||||
|
nonce := make([]byte, 0, 16)
|
||||||
|
for _, part := range hexParts {
|
||||||
|
if len(part) != 8 {
|
||||||
|
return gpgSocketInfo{}, false
|
||||||
|
}
|
||||||
|
value, err := strconv.ParseUint(part, 16, 32)
|
||||||
|
if err != nil {
|
||||||
|
return gpgSocketInfo{}, false
|
||||||
|
}
|
||||||
|
var chunk [4]byte
|
||||||
|
binary.LittleEndian.PutUint32(chunk[:], uint32(value))
|
||||||
|
nonce = append(nonce, chunk[:]...)
|
||||||
|
}
|
||||||
|
return gpgSocketInfo{port: uint16(port64), nonce: nonce, cygwin: true}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func isAgentSSHSocketPath(endpoint string) bool {
|
||||||
|
normalized := strings.ToLower(strings.TrimSpace(endpoint))
|
||||||
|
return strings.HasSuffix(normalized, "s.gpg-agent.ssh")
|
||||||
|
}
|
||||||
|
|
||||||
|
func dialWindowsNamedPipeContext(ctx context.Context, path string, unavailableOnNotFound bool) (net.Conn, error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
conn, err := winio.DialPipeContext(ctx, path)
|
||||||
|
if err != nil && unavailableOnNotFound && isWindowsPipeUnavailable(err) {
|
||||||
|
return nil, errSSHAgentUnavailable
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,152 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package starssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseGPGAssuanSocketInfo(t *testing.T) {
|
||||||
|
info, ok := parseGPGAssuanSocketInfo([]byte("7247\n0123456789abcdef"))
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected Assuan socket info to parse")
|
||||||
|
}
|
||||||
|
if info.port != 7247 || string(info.nonce) != "0123456789abcdef" || info.cygwin {
|
||||||
|
t.Fatalf("info=%+v nonce=%x", info, info.nonce)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseGPGCygwinSocketInfo(t *testing.T) {
|
||||||
|
info, ok := parseGPGCygwinSocketInfo([]byte("!<socket >7247 s 00000001-02030405-06070809-0a0b0c0d\x00"))
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected Cygwin socket info to parse")
|
||||||
|
}
|
||||||
|
want := []byte{1, 0, 0, 0, 5, 4, 3, 2, 9, 8, 7, 6, 13, 12, 11, 10}
|
||||||
|
if info.port != 7247 || string(info.nonce) != string(want) || !info.cygwin {
|
||||||
|
t.Fatalf("info=%+v nonce=%x", info, info.nonce)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseGPGAssuanSocketRedirect(t *testing.T) {
|
||||||
|
t.Setenv("STARSSH_TEST_PIPE", `\\.\pipe\openssh-ssh-agent`)
|
||||||
|
target, ok := parseGPGAssuanSocketRedirect([]byte("%Assuan%\r\nsocket=${STARSSH_TEST_PIPE}\r\n"))
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected Assuan redirect to parse")
|
||||||
|
}
|
||||||
|
if target != `\\.\pipe\openssh-ssh-agent` {
|
||||||
|
t.Fatalf("target=%q", target)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadInvalidAgentSSHSocketReturnsGPGSocketError(t *testing.T) {
|
||||||
|
path := t.TempDir() + "/S.gpg-agent.ssh"
|
||||||
|
if err := os.WriteFile(path, []byte("not a socket info file"), 0o600); err != nil {
|
||||||
|
t.Fatalf("write socket file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := dialResolvedSSHAgent(resolvedSSHAgentEndpoint{
|
||||||
|
Endpoint: path,
|
||||||
|
Source: "SSH_AUTH_SOCK",
|
||||||
|
Network: defaultSSHAgentNetwork(path),
|
||||||
|
}, 0)
|
||||||
|
if !errors.Is(err, errInvalidGPGSocketInfo) {
|
||||||
|
t.Fatalf("err=%v want errInvalidGPGSocketInfo", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMissingAgentSSHSocketReturnsReadError(t *testing.T) {
|
||||||
|
path := filepath.Join(t.TempDir(), "S.gpg-agent.ssh")
|
||||||
|
|
||||||
|
_, err := dialResolvedSSHAgent(resolvedSSHAgentEndpoint{
|
||||||
|
Endpoint: path,
|
||||||
|
Source: "identity-agent",
|
||||||
|
Network: defaultSSHAgentNetwork(path),
|
||||||
|
}, 0)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected missing GPG socket file error")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, os.ErrNotExist) {
|
||||||
|
t.Fatalf("err=%v want os.ErrNotExist", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnreadableAgentSSHSocketReturnsReadError(t *testing.T) {
|
||||||
|
path := filepath.Join(t.TempDir(), "S.gpg-agent.ssh")
|
||||||
|
if err := os.Mkdir(path, 0o700); err != nil {
|
||||||
|
t.Fatalf("mkdir socket path: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := dialResolvedSSHAgent(resolvedSSHAgentEndpoint{
|
||||||
|
Endpoint: path,
|
||||||
|
Source: "identity-agent",
|
||||||
|
Network: defaultSSHAgentNetwork(path),
|
||||||
|
}, 0)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected unreadable GPG socket file error")
|
||||||
|
}
|
||||||
|
if errors.Is(err, errInvalidGPGSocketInfo) {
|
||||||
|
t.Fatalf("err=%v should expose read failure before parse", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDialWindowsGPGSocketFilePerformsNonceHandshake(t *testing.T) {
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen tcp: %v", err)
|
||||||
|
}
|
||||||
|
defer listener.Close()
|
||||||
|
|
||||||
|
type handshakeResult struct {
|
||||||
|
nonce []byte
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
resultCh := make(chan handshakeResult, 1)
|
||||||
|
go func() {
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
resultCh <- handshakeResult{err: err}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
nonce := make([]byte, 16)
|
||||||
|
if _, err := io.ReadFull(conn, nonce); err != nil {
|
||||||
|
resultCh <- handshakeResult{err: err}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resultCh <- handshakeResult{nonce: append([]byte(nil), nonce...)}
|
||||||
|
}()
|
||||||
|
|
||||||
|
socketPath := filepath.Join(t.TempDir(), "S.gpg-agent.ssh")
|
||||||
|
if err := os.WriteFile(socketPath, []byte(strconv.Itoa(listener.Addr().(*net.TCPAddr).Port)+"\n0123456789abcdef"), 0o600); err != nil {
|
||||||
|
t.Fatalf("write socket file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := dialWindowsGPGSocketFile(socketPath, time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dialWindowsGPGSocketFile: %v", err)
|
||||||
|
}
|
||||||
|
_ = conn.Close()
|
||||||
|
|
||||||
|
var result handshakeResult
|
||||||
|
select {
|
||||||
|
case result = <-resultCh:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("listener did not accept GPG socket connection")
|
||||||
|
}
|
||||||
|
if result.err != nil {
|
||||||
|
t.Fatalf("listener handshake error: %v", result.err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(result.nonce, []byte("0123456789abcdef")) {
|
||||||
|
t.Fatalf("nonce=%q", result.nonce)
|
||||||
|
}
|
||||||
|
}
|
||||||
+2
-2
@@ -32,7 +32,7 @@ func resolveDialContext(info LoginInput) DialContextFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dialer := &net.Dialer{
|
dialer := &net.Dialer{
|
||||||
Timeout: info.Timeout,
|
Timeout: effectiveDialTimeout(info),
|
||||||
}
|
}
|
||||||
return dialer.DialContext
|
return dialer.DialContext
|
||||||
}
|
}
|
||||||
@@ -44,7 +44,7 @@ func dialTargetConn(ctx context.Context, info LoginInput) (net.Conn, *StarSSH, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
dialContext := resolveDialContext(info)
|
dialContext := resolveDialContext(info)
|
||||||
proxyConfig := normalizeProxyConfig(info.Proxy, info.Timeout)
|
proxyConfig := normalizeProxyConfig(info.Proxy, effectiveDialTimeout(info))
|
||||||
if proxyConfig != nil {
|
if proxyConfig != nil {
|
||||||
return dialViaProxy(ctx, dialContext, *proxyConfig, targetAddr)
|
return dialViaProxy(ctx, dialContext, *proxyConfig, targetAddr)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
const (
|
const (
|
||||||
defaultSSHPort = 22
|
defaultSSHPort = 22
|
||||||
defaultLoginTimeout = 5 * time.Second
|
defaultLoginTimeout = 5 * time.Second
|
||||||
|
defaultSSHAgentTimeout = 2 * time.Minute
|
||||||
defaultKeepAliveTimeout = 3 * time.Second
|
defaultKeepAliveTimeout = 3 * time.Second
|
||||||
defaultShellPollInterval = 120 * time.Millisecond
|
defaultShellPollInterval = 120 * time.Millisecond
|
||||||
defaultShellSetupDelay = 200 * time.Millisecond
|
defaultShellSetupDelay = 200 * time.Millisecond
|
||||||
@@ -58,6 +59,20 @@ const (
|
|||||||
AuthMethodSSHAgent AuthMethodKind = "ssh_agent"
|
AuthMethodSSHAgent AuthMethodKind = "ssh_agent"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type SSHAgentDebugFunc func(SSHAgentDebugEvent)
|
||||||
|
|
||||||
|
type SSHAgentDebugEvent struct {
|
||||||
|
Step string
|
||||||
|
Source string
|
||||||
|
Endpoint string
|
||||||
|
Network string
|
||||||
|
Phase string
|
||||||
|
Status string
|
||||||
|
Duration time.Duration
|
||||||
|
KeyCount int
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
type StarSSH struct {
|
type StarSSH struct {
|
||||||
stateMu sync.RWMutex
|
stateMu sync.RWMutex
|
||||||
Client *ssh.Client
|
Client *ssh.Client
|
||||||
@@ -92,16 +107,37 @@ type LoginInput struct {
|
|||||||
DisableSSHAgent bool
|
DisableSSHAgent bool
|
||||||
ForwardSSHAgent bool
|
ForwardSSHAgent bool
|
||||||
AuthOrder []AuthMethodKind
|
AuthOrder []AuthMethodKind
|
||||||
Addr string
|
// IdentityAgent overrides the local ssh-agent endpoint used for authentication
|
||||||
Port int
|
// and agent forwarding. Empty uses SSH_AUTH_SOCK, or the platform default where
|
||||||
Timeout time.Duration
|
// one exists.
|
||||||
DialContext DialContextFunc
|
IdentityAgent string
|
||||||
Proxy *ProxyConfig
|
Addr string
|
||||||
Jump *LoginInput
|
Port int
|
||||||
KeepAliveInterval time.Duration
|
// Timeout limits the SSH handshake/authentication phase after a TCP connection has
|
||||||
KeepAliveTimeout time.Duration
|
// already been established. Zero means no authentication timeout.
|
||||||
HostKeyCallback func(string, net.Addr, ssh.PublicKey) error
|
Timeout time.Duration
|
||||||
BannerCallback func(string) error
|
// DialTimeout limits outbound dial steps such as TCP connect, proxy connect, and
|
||||||
|
// local ssh-agent socket connect. Zero falls back to Timeout when set, otherwise
|
||||||
|
// uses the package default dial timeout. Negative disables the default dial timeout.
|
||||||
|
DialTimeout time.Duration
|
||||||
|
// SSHAgentTimeout limits ssh-agent protocol operations such as listing keys and
|
||||||
|
// signing challenges. Zero uses the package default, and negative disables the
|
||||||
|
// per-operation deadline. This is intentionally separate from Timeout and
|
||||||
|
// DialTimeout because hardware-backed agents may require a PIN or touch confirmation.
|
||||||
|
SSHAgentTimeout time.Duration
|
||||||
|
// SSHAgentForwardTimeout limits idle reads and writes on forwarded agent
|
||||||
|
// channels. Zero or negative leaves forwarded channels without an idle deadline.
|
||||||
|
SSHAgentForwardTimeout time.Duration
|
||||||
|
// SSHAgentDebug receives structured ssh-agent dial/protocol events. It is nil by
|
||||||
|
// default and must not log private key material.
|
||||||
|
SSHAgentDebug SSHAgentDebugFunc
|
||||||
|
DialContext DialContextFunc
|
||||||
|
Proxy *ProxyConfig
|
||||||
|
Jump *LoginInput
|
||||||
|
KeepAliveInterval time.Duration
|
||||||
|
KeepAliveTimeout time.Duration
|
||||||
|
HostKeyCallback func(string, net.Addr, ssh.PublicKey) error
|
||||||
|
BannerCallback func(string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// StarShell keeps the legacy prompt-driven helper for POSIX-style scripted shell interactions.
|
// StarShell keeps the legacy prompt-driven helper for POSIX-style scripted shell interactions.
|
||||||
|
|||||||
Reference in New Issue
Block a user