starssh/sshagent_conn.go

159 lines
3.4 KiB
Go
Raw Permalink Normal View History

package starssh
import (
"errors"
"fmt"
"net"
"os"
"strings"
"time"
)
// ErrSSHAgentTimeout is the sentinel error that this file wraps (via %w)
// around ssh-agent dial and connection I/O timeouts, so callers can detect
// them with errors.Is(err, ErrSSHAgentTimeout).
var ErrSSHAgentTimeout = errors.New("ssh-agent timeout")
// dialResolvedSSHAgentFunc is the function dialSSHAgent calls to connect to a
// resolved endpoint. It defaults to dialResolvedSSHAgent and is kept in a
// variable so it can be swapped out (presumably a test seam — confirm).
var dialResolvedSSHAgentFunc = dialResolvedSSHAgent
// sshAgentDialOptions carries the caller's inputs for dialSSHAgent.
type sshAgentDialOptions struct {
// Endpoint is an explicit agent endpoint. It is whitespace-trimmed; when it
// is empty, resolution falls back to $SSH_AUTH_SOCK and then to
// defaultSSHAgentEndpoint.
Endpoint string
// Timeout is forwarded unchanged to the dial function. How a zero or
// negative value is treated is up to dialResolvedSSHAgent
// (NOTE(review): confirm).
Timeout time.Duration
}
// resolvedSSHAgentEndpoint describes the agent endpoint chosen by
// resolveSSHAgentEndpoint, together with where it came from.
type resolvedSSHAgentEndpoint struct {
// Endpoint is the trimmed address to dial.
Endpoint string
// Source names where Endpoint was found: "identity-agent" for an explicit
// option, "SSH_AUTH_SOCK" for the environment variable, or whatever
// defaultSSHAgentEndpoint reports.
Source string
// Network is the value defaultSSHAgentNetwork computed for Endpoint
// (presumably a net.Dial network name such as "unix" — confirm).
Network string
}
// deadlineAgentConn wraps a net.Conn and sets a fresh deadline of `timeout`
// from now before every Read and Write. Each individual operation is bounded,
// rather than the connection's whole lifetime. Timeout errors from the wrapped
// conn are reported as errors wrapping ErrSSHAgentTimeout. Instances are built
// by wrapSSHAgentConnWithDeadline.
type deadlineAgentConn struct {
net.Conn
// timeout is the per-operation deadline; non-positive values disable it.
timeout time.Duration
}
// resolveSSHAgentEndpoint picks the ssh-agent endpoint to dial. Precedence:
//  1. options.Endpoint (reported with Source "identity-agent"),
//  2. the SSH_AUTH_SOCK environment variable (Source "SSH_AUTH_SOCK"),
//  3. whatever defaultSSHAgentEndpoint returns.
//
// Candidate values are whitespace-trimmed and skipped when blank. For the
// first two sources, Network is filled in by defaultSSHAgentNetwork.
func resolveSSHAgentEndpoint(options sshAgentDialOptions) (resolvedSSHAgentEndpoint, error) {
	fromValue := func(raw, source string) (resolvedSSHAgentEndpoint, bool) {
		addr := strings.TrimSpace(raw)
		if addr == "" {
			return resolvedSSHAgentEndpoint{}, false
		}
		return resolvedSSHAgentEndpoint{
			Endpoint: addr,
			Source:   source,
			Network:  defaultSSHAgentNetwork(addr),
		}, true
	}

	if picked, ok := fromValue(options.Endpoint, "identity-agent"); ok {
		return picked, nil
	}
	// The environment is only consulted when no explicit endpoint was given.
	if picked, ok := fromValue(os.Getenv("SSH_AUTH_SOCK"), "SSH_AUTH_SOCK"); ok {
		return picked, nil
	}
	return defaultSSHAgentEndpoint()
}
// dialSSHAgent resolves the agent endpoint from options and connects to it
// through dialResolvedSSHAgentFunc.
//
// If resolution fails, it returns the zero endpoint and the resolution error.
// If dialing fails, it returns a nil conn but still returns the resolved
// endpoint, so callers can report what was attempted. Timeout failures are
// wrapped so that errors.Is(err, ErrSSHAgentTimeout) reports true.
func dialSSHAgent(options sshAgentDialOptions) (net.Conn, resolvedSSHAgentEndpoint, error) {
	resolved, err := resolveSSHAgentEndpoint(options)
	if err != nil {
		return nil, resolvedSSHAgentEndpoint{}, err
	}

	conn, dialErr := dialResolvedSSHAgentFunc(resolved, options.Timeout)
	switch {
	case dialErr == nil:
		return conn, resolved, nil
	case isTimeoutError(dialErr):
		return nil, resolved, fmt.Errorf("%w: %v", ErrSSHAgentTimeout, dialErr)
	default:
		return nil, resolved, dialErr
	}
}
// dialSSHAgentWithDebug dials the agent using the endpoint and dial timeout
// taken from timeouts. It then reports a single "dial" debug event to
// timeouts.Debug (skipped when that callback is nil). The event carries step,
// the resolved endpoint details, an "ok"/"error" status, the elapsed time and
// any error. The dial results are returned unchanged.
func dialSSHAgentWithDebug(step string, timeouts sshAgentTimeouts) (net.Conn, resolvedSSHAgentEndpoint, error) {
	begin := time.Now()
	conn, resolved, err := dialSSHAgent(sshAgentDialOptions{
		Endpoint: timeouts.Endpoint,
		Timeout:  timeouts.Dial,
	})
	elapsed := time.Since(begin)

	logSSHAgentDebug(timeouts.Debug, SSHAgentDebugEvent{
		Step:     step,
		Phase:    "dial",
		Source:   resolved.Source,
		Endpoint: resolved.Endpoint,
		Network:  resolved.Network,
		Status:   debugStatus(err),
		Duration: elapsed,
		Err:      err,
	})
	return conn, resolved, err
}
// logSSHAgentDebug forwards event to the debug callback. A nil callback makes
// it a no-op, so callers need not check before calling.
func logSSHAgentDebug(debug SSHAgentDebugFunc, event SSHAgentDebugEvent) {
	if debug != nil {
		debug(event)
	}
}
// debugStatus maps an operation's error to the status label used in debug
// events: "ok" for a nil error and "error" otherwise.
func debugStatus(err error) string {
	status := "ok"
	if err != nil {
		status = "error"
	}
	return status
}
// wrapSSHAgentConnWithDeadline returns conn wrapped in a deadlineAgentConn so
// that every Read and Write is bounded by timeout. A nil conn or a
// non-positive timeout disables wrapping, and conn is returned as-is.
func wrapSSHAgentConnWithDeadline(conn net.Conn, timeout time.Duration) net.Conn {
	if conn != nil && timeout > 0 {
		return &deadlineAgentConn{Conn: conn, timeout: timeout}
	}
	return conn
}
// Read refreshes the per-operation deadline, then reads from the wrapped
// connection. A timeout is reported as an error wrapping ErrSSHAgentTimeout;
// any other error is passed through unchanged.
func (c *deadlineAgentConn) Read(p []byte) (int, error) {
	c.setDeadline()
	count, readErr := c.Conn.Read(p)
	if readErr != nil {
		readErr = wrapSSHAgentConnError(readErr)
	}
	return count, readErr
}
// Write refreshes the per-operation deadline, then writes to the wrapped
// connection. A timeout is reported as an error wrapping ErrSSHAgentTimeout;
// any other error is passed through unchanged.
func (c *deadlineAgentConn) Write(p []byte) (int, error) {
	c.setDeadline()
	count, writeErr := c.Conn.Write(p)
	if writeErr != nil {
		writeErr = wrapSSHAgentConnError(writeErr)
	}
	return count, writeErr
}
// setDeadline moves the wrapped connection's combined read/write deadline to
// `timeout` from now. It does nothing for a nil receiver, a non-positive
// timeout, or a nil underlying conn.
func (c *deadlineAgentConn) setDeadline() {
	if c != nil && c.timeout > 0 && c.Conn != nil {
		// Best effort: a SetDeadline failure is deliberately ignored, and the
		// subsequent Read/Write proceeds without a refreshed deadline.
		_ = c.Conn.SetDeadline(time.Now().Add(c.timeout))
	}
}
// isTimeoutError reports whether err, or anything in its wrap chain, is a
// timeout. That means either os.ErrDeadlineExceeded, or the first net.Error in
// the chain reporting Timeout() == true. A nil error is never a timeout.
func isTimeoutError(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, os.ErrDeadlineExceeded):
		return true
	}

	var netErr net.Error
	if !errors.As(err, &netErr) {
		return false
	}
	return netErr.Timeout()
}
// wrapSSHAgentConnError tags timeout errors with ErrSSHAgentTimeout so callers
// can match them with errors.Is. Any other error, including nil, is returned
// unchanged.
//
// Note: the original error is formatted into the message with %v, so it is
// NOT part of the resulting wrap chain. Only the sentinel can be matched.
func wrapSSHAgentConnError(err error) error {
	if !isTimeoutError(err) {
		return err
	}
	return fmt.Errorf("%w: %v", ErrSSHAgentTimeout, err)
}
// normalizeSSHAgentError makes agent timeouts recognizable via
// errors.Is(err, ErrSSHAgentTimeout):
//   - nil stays nil;
//   - errors already wrapping the sentinel are returned as-is;
//   - errors whose text merely contains the sentinel's message are re-wrapped
//     with the sentinel (presumably because some layer flattened the chain
//     with %v — confirm against callers);
//   - anything else is returned unchanged.
func normalizeSSHAgentError(err error) error {
	if err == nil {
		return nil
	}
	switch {
	case errors.Is(err, ErrSSHAgentTimeout):
		return err
	case strings.Contains(err.Error(), ErrSSHAgentTimeout.Error()):
		return fmt.Errorf("%w: %v", ErrSSHAgentTimeout, err)
	default:
		return err
	}
}