package starssh

import (
	"context"
	"encoding/base64"
	"errors"
	"fmt"
	"net"
	"os"
	"strings"
	"time"

	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/agent"
)

var (
	// ErrHostKeyCallbackRequired is returned by Login and LoginContext when
	// LoginInput.HostKeyCallback is nil. Host-key verification is never skipped
	// implicitly; pass DefaultAllowHostKeyCallback to opt out explicitly.
	ErrHostKeyCallbackRequired = errors.New("host key callback is required; use DefaultAllowHostKeyCallback to explicitly allow any host key")

	// errSSHAgentUnavailable is treated by buildSSHAgentAuthMethod as "no agent
	// configured" (no method, no error) rather than as a failure.
	errSSHAgentUnavailable = errors.New("ssh-agent unavailable")

	// buildSSHAgentAuthMethodFunc indirects buildSSHAgentAuthMethod through a
	// variable so it can be swapped out (presumably by tests — confirm).
	buildSSHAgentAuthMethodFunc = buildSSHAgentAuthMethod

	// defaultAuthOrder is the authentication order used when
	// LoginInput.AuthOrder is empty.
	defaultAuthOrder = []AuthMethodKind{
		AuthMethodSSHAgent,
		AuthMethodPrivateKey,
		AuthMethodPassword,
		AuthMethodKeyboardInteractive,
	}
)

// DefaultAllowHostKeyCallback accepts every host key without verification.
// It offers no protection against man-in-the-middle attacks; use it only when
// that risk is acceptable.
func DefaultAllowHostKeyCallback(hostname string, remote net.Addr, key ssh.PublicKey) error {
	return nil
}

// LoginContext dials the target described by info and performs the SSH
// handshake and authentication. ctx, combined with LoginInput.Timeout when it
// is positive, is passed to the dialer and used to derive the connection
// deadline for the handshake.
//
// Input errors (missing host-key callback, no usable auth method) return a nil
// *StarSSH. Dial and handshake errors return a partially populated *StarSSH,
// so callers can inspect e.g. the host key the server presented.
func LoginContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
	return loginWithContext(ctx, info)
}

// Login is LoginContext with context.Background().
func Login(info LoginInput) (*StarSSH, error) {
	return LoginContext(context.Background(), info)
}

// loginWithContext implements LoginContext.
func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
	info = normalizeLoginInput(info)
	if info.HostKeyCallback == nil {
		return nil, ErrHostKeyCallbackRequired
	}

	authTimeout := effectiveLoginTimeout(info)
	loginCtx, cancel := contextWithLoginTimeout(ctx, authTimeout)
	defer cancel()

	sshInfo := &StarSSH{
		LoginInfo: info,
	}

	auth, authCleanup, err := buildAuthMethods(info)
	if err != nil {
		return nil, err
	}
	if authCleanup != nil {
		// Auth resources (e.g. the ssh-agent connection) are only needed while
		// the handshake runs.
		defer authCleanup()
	}

	hostKeyCallback := func(hostname string, remote net.Addr, key ssh.PublicKey) error {
		// Record what the server presented before delegating, so the values
		// are visible on the returned *StarSSH even if verification fails.
		sshInfo.PublicKey = key
		sshInfo.RemoteAddr = remote
		sshInfo.Hostname = hostname
		if key != nil {
			sshInfo.PubkeyBase64 = base64.StdEncoding.EncodeToString(key.Marshal())
		}
		return info.HostKeyCallback(hostname, remote, key)
	}
	bannerCallback := func(banner string) error {
		sshInfo.Banner = banner
		if info.BannerCallback != nil {
			return info.BannerCallback(banner)
		}
		return nil
	}

	clientConfig := &ssh.ClientConfig{
		User:            info.User,
		Auth:            auth,
		Timeout:         authTimeout,
		HostKeyCallback: hostKeyCallback,
		BannerCallback:  bannerCallback,
		// Empty slices leave the x/crypto/ssh defaults in effect, so this is
		// equivalent to an unset Config when none are provided.
		Config: ssh.Config{
			Ciphers:      info.Ciphers,
			MACs:         info.MACs,
			KeyExchanges: info.KeyExchanges,
		},
	}

	targetAddr := joinHostPort(info.Addr, info.Port)
	rawConn, upstream, err := dialTargetConn(loginCtx, info)
	if err != nil {
		return sshInfo, err
	}
	restoreDeadline := applyConnDeadline(rawConn, loginCtx, authTimeout)
	defer restoreDeadline()

	clientConn, chans, reqs, err := ssh.NewClientConn(rawConn, targetAddr, clientConfig)
	if err != nil {
		// Best-effort teardown; the handshake error is what gets reported.
		_ = rawConn.Close()
		if upstream != nil {
			_ = upstream.Close()
		}
		return sshInfo, err
	}

	sshInfo.setTransport(ssh.NewClient(clientConn, chans, reqs), upstream)
	sshInfo.startAutoKeepAlive()
	return sshInfo, nil
}

// contextWithLoginTimeout derives a context bounded by timeout. A nil ctx is
// replaced with context.Background(); a non-positive timeout returns ctx
// unchanged together with a no-op cancel func.
func contextWithLoginTimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
	if ctx == nil {
		ctx = context.Background()
	}
	if timeout <= 0 {
		return ctx, func() {}
	}
	return context.WithTimeout(ctx, timeout)
}

// LoginSimple is a convenience wrapper around Login.
//
// When prikeyPath is non-empty the key file is read and passwd is used as its
// passphrase; otherwise passwd is used for password authentication. timeout is
// used both as the overall login timeout and as the dial timeout.
//
// WARNING: LoginSimple accepts any host key (DefaultAllowHostKeyCallback).
func LoginSimple(host string, user string, passwd string, prikeyPath string, port int, timeout time.Duration) (*StarSSH, error) {
	info := LoginInput{
		Addr:            host,
		Port:            port,
		Timeout:         timeout,
		DialTimeout:     timeout,
		User:            user,
		HostKeyCallback: DefaultAllowHostKeyCallback,
	}
	if prikeyPath == "" {
		info.Password = passwd
		return Login(info)
	}

	prikey, err := os.ReadFile(prikeyPath)
	if err != nil {
		return nil, err
	}
	info.Prikey = string(prikey)
	// An empty passwd leaves PrikeyPwd at its zero value, i.e. "unencrypted key".
	info.PrikeyPwd = passwd
	return Login(info)
}

// normalizeLoginInput fills defaults into a copy of info: a non-positive Port
// becomes defaultSSHPort.
func normalizeLoginInput(info LoginInput) LoginInput {
	if info.Port <= 0 {
		info.Port = defaultSSHPort
	}
	return info
}

// effectiveLoginTimeout returns info.Timeout, or 0 (no overall login timeout,
// see contextWithLoginTimeout) when it is not positive.
func effectiveLoginTimeout(info LoginInput) time.Duration {
	if info.Timeout <= 0 {
		return 0
	}
	return info.Timeout
}

// effectiveDialTimeout resolves the dial timeout: a negative DialTimeout yields
// 0 (presumably "no dial timeout" for the dialer — confirm), a positive
// DialTimeout wins, then a positive Timeout, then defaultLoginTimeout.
func effectiveDialTimeout(info LoginInput) time.Duration {
	switch {
	case info.DialTimeout < 0:
		return 0
	case info.DialTimeout > 0:
		return info.DialTimeout
	case info.Timeout > 0:
		return info.Timeout
	default:
		return defaultLoginTimeout
	}
}

// buildAuthMethods assembles the ssh.AuthMethods for info in the order given by
// info.AuthOrder (or defaultAuthOrder). Methods with no configured credentials
// are skipped; ssh-agent failures are soft and only reported if no method at
// all could be built. On success the returned cleanup (possibly nil) releases
// resources such as the ssh-agent connection and must be called once the
// handshake is done. On error every resource acquired so far has already been
// released.
func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) {
	order, err := normalizeAuthOrder(info.AuthOrder)
	if err != nil {
		return nil, nil, err
	}

	auth := make([]ssh.AuthMethod, 0, len(order))
	var (
		agentErr     error
		cleanupFuncs []func()
	)
	// fail releases anything acquired so far (e.g. an already-opened ssh-agent
	// connection) before returning cause. The caller never receives a cleanup
	// func on the error path, so without this those resources would leak.
	fail := func(cause error) ([]ssh.AuthMethod, func(), error) {
		if release := composeCleanup(cleanupFuncs...); release != nil {
			release()
		}
		return nil, nil, cause
	}

	for _, methodKind := range order {
		switch methodKind {
		case AuthMethodPrivateKey:
			method, err := buildPrivateKeyAuthMethod(info)
			if err != nil {
				return fail(err)
			}
			if method != nil {
				auth = append(auth, method)
			}
		case AuthMethodPassword:
			if method := buildPasswordAuthMethod(info.Password, info.PasswordCallback); method != nil {
				auth = append(auth, method)
			}
		case AuthMethodKeyboardInteractive:
			method := buildKeyboardInteractiveAuthMethod(info.Password, info.PasswordCallback, info.KeyboardInteractiveCallback)
			if method != nil {
				auth = append(auth, method)
			}
		case AuthMethodSSHAgent:
			if info.DisableSSHAgent {
				continue
			}
			agentMethod, cleanup, err := buildSSHAgentAuthMethodFunc(effectiveDialTimeout(info))
			if err != nil {
				// Soft failure: remember why and try the remaining methods.
				agentErr = err
				continue
			}
			if agentMethod != nil {
				auth = append(auth, agentMethod)
			}
			if cleanup != nil {
				cleanupFuncs = append(cleanupFuncs, cleanup)
			}
		}
	}

	if len(auth) == 0 {
		if agentErr != nil {
			return fail(fmt.Errorf("no authentication method provided; ssh-agent unavailable: %w", agentErr))
		}
		return fail(errors.New("no authentication method provided: password, private key, or ssh-agent is required"))
	}
	return auth, composeCleanup(cleanupFuncs...), nil
}

// normalizeAuthOrder trims and lower-cases each entry of order, rejects empty
// or unsupported kinds, and drops duplicates (first occurrence wins). An empty
// order yields a fresh copy of defaultAuthOrder.
func normalizeAuthOrder(order []AuthMethodKind) ([]AuthMethodKind, error) {
	if len(order) == 0 {
		return append([]AuthMethodKind(nil), defaultAuthOrder...), nil
	}

	normalized := make([]AuthMethodKind, 0, len(order))
	seen := make(map[AuthMethodKind]struct{}, len(order))
	for _, raw := range order {
		kind := AuthMethodKind(strings.ToLower(strings.TrimSpace(string(raw))))
		if kind == "" {
			return nil, errors.New("auth order contains an empty auth method")
		}
		if !isSupportedAuthMethodKind(kind) {
			return nil, fmt.Errorf("unsupported auth method %q", raw)
		}
		if _, exists := seen[kind]; exists {
			continue
		}
		seen[kind] = struct{}{}
		normalized = append(normalized, kind)
	}
	// Defensive: with a non-empty order every entry either errors above or is
	// kept, so this should be unreachable.
	if len(normalized) == 0 {
		return nil, errors.New("auth order is empty")
	}
	return normalized, nil
}

// isSupportedAuthMethodKind reports whether kind is one of the auth methods
// buildAuthMethods knows how to build.
func isSupportedAuthMethodKind(kind AuthMethodKind) bool {
	switch kind {
	case AuthMethodPrivateKey, AuthMethodPassword, AuthMethodKeyboardInteractive, AuthMethodSSHAgent:
		return true
	default:
		return false
	}
}

// buildPrivateKeyAuthMethod returns a public-key auth method for info.Prikey,
// decrypting it with info.PrikeyPwd when that is non-empty. It returns
// (nil, nil) when no key is configured and the parse error when the key
// cannot be loaded.
func buildPrivateKeyAuthMethod(info LoginInput) (ssh.AuthMethod, error) {
	if strings.TrimSpace(info.Prikey) == "" {
		return nil, nil
	}
	pemBytes := []byte(info.Prikey)

	var (
		signer ssh.Signer
		err    error
	)
	if info.PrikeyPwd == "" {
		signer, err = ssh.ParsePrivateKey(pemBytes)
	} else {
		signer, err = ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(info.PrikeyPwd))
	}
	if err != nil {
		return nil, err
	}
	return ssh.PublicKeys(signer), nil
}

// buildPasswordAuthMethod prefers a static password, falls back to callback,
// and returns nil when neither is set.
func buildPasswordAuthMethod(password string, callback func() (string, error)) ssh.AuthMethod {
	if password != "" {
		return ssh.Password(password)
	}
	if callback != nil {
		return ssh.PasswordCallback(callback)
	}
	return nil
}

// buildKeyboardInteractiveAuthMethod returns a keyboard-interactive auth
// method. An explicit challenge handler takes precedence. Otherwise every
// question in a challenge is answered with the same value: the static password
// or, when it is empty, the result of passwordCallback (invoked once per
// challenge round that has questions). It returns nil when no credential
// source is configured.
func buildKeyboardInteractiveAuthMethod(
	password string,
	passwordCallback func() (string, error),
	challenge ssh.KeyboardInteractiveChallenge,
) ssh.AuthMethod {
	if challenge != nil {
		return ssh.KeyboardInteractive(challenge)
	}
	if password == "" && passwordCallback == nil {
		return nil
	}

	keyboardInteractiveChallenge := func(user, instruction string, questions []string, echos []bool) ([]string, error) {
		if len(questions) == 0 {
			return []string{}, nil
		}
		answer := password
		if answer == "" {
			var err error
			answer, err = passwordCallback()
			if err != nil {
				return nil, err
			}
		}
		answers := make([]string, len(questions))
		for i := range questions {
			answers[i] = answer
		}
		return answers, nil
	}
	return ssh.KeyboardInteractive(keyboardInteractiveChallenge)
}

// buildSSHAgentAuthMethod connects to the ssh-agent (via dialSSHAgent) and
// returns a public-key auth method backed by the agent's loaded keys, plus a
// cleanup that closes the agent connection. It returns (nil, nil, nil) when the
// agent is unavailable, and an error when the agent cannot list keys or has
// none loaded; on every non-success path the connection is closed.
func buildSSHAgentAuthMethod(timeout time.Duration) (ssh.AuthMethod, func(), error) {
	conn, err := dialSSHAgent(timeout)
	if err != nil {
		if errors.Is(err, errSSHAgentUnavailable) {
			return nil, nil, nil
		}
		return nil, nil, err
	}
	if conn == nil {
		return nil, nil, nil
	}

	signers, err := agent.NewClient(conn).Signers()
	if err != nil {
		_ = conn.Close()
		return nil, nil, err
	}
	if len(signers) == 0 {
		_ = conn.Close()
		return nil, nil, errors.New("ssh-agent has no loaded keys")
	}
	return ssh.PublicKeys(signers...), func() { _ = conn.Close() }, nil
}

// composeCleanup combines funcs into a single cleanup that runs them in reverse
// order, skipping nil entries. It returns nil when funcs is empty.
func composeCleanup(funcs ...func()) func() {
	if len(funcs) == 0 {
		return nil
	}
	return func() {
		for i := len(funcs) - 1; i >= 0; i-- {
			if funcs[i] != nil {
				funcs[i]()
			}
		}
	}
}