363 lines
8.5 KiB
Go
363 lines
8.5 KiB
Go
|
|
package starssh
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"encoding/base64"
|
||
|
|
"errors"
|
||
|
|
"fmt"
|
||
|
|
"net"
|
||
|
|
"os"
|
||
|
|
"strings"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"golang.org/x/crypto/ssh"
|
||
|
|
"golang.org/x/crypto/ssh/agent"
|
||
|
|
)
|
||
|
|
|
||
|
|
var ErrHostKeyCallbackRequired = errors.New("host key callback is required; use DefaultAllowHostKeyCallback to explicitly allow any host key")
|
||
|
|
var errSSHAgentUnavailable = errors.New("ssh-agent unavailable")
|
||
|
|
|
||
|
|
var defaultAuthOrder = []AuthMethodKind{
|
||
|
|
AuthMethodSSHAgent,
|
||
|
|
AuthMethodPrivateKey,
|
||
|
|
AuthMethodPassword,
|
||
|
|
AuthMethodKeyboardInteractive,
|
||
|
|
}
|
||
|
|
|
||
|
|
func DefaultAllowHostKeyCallback(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func LoginContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
|
||
|
|
return loginWithContext(ctx, info)
|
||
|
|
}
|
||
|
|
|
||
|
|
func Login(info LoginInput) (*StarSSH, error) {
|
||
|
|
return LoginContext(context.Background(), info)
|
||
|
|
}
|
||
|
|
|
||
|
|
func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
|
||
|
|
info = normalizeLoginInput(info)
|
||
|
|
if info.HostKeyCallback == nil {
|
||
|
|
return nil, ErrHostKeyCallbackRequired
|
||
|
|
}
|
||
|
|
|
||
|
|
loginCtx, cancel := contextWithLoginTimeout(ctx, info.Timeout)
|
||
|
|
defer cancel()
|
||
|
|
|
||
|
|
sshInfo := &StarSSH{
|
||
|
|
LoginInfo: info,
|
||
|
|
}
|
||
|
|
|
||
|
|
auth, authCleanup, err := buildAuthMethods(info)
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
if authCleanup != nil {
|
||
|
|
defer authCleanup()
|
||
|
|
}
|
||
|
|
|
||
|
|
hostKeyCallback := func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||
|
|
sshInfo.PublicKey = key
|
||
|
|
sshInfo.RemoteAddr = remote
|
||
|
|
sshInfo.Hostname = hostname
|
||
|
|
|
||
|
|
return info.HostKeyCallback(hostname, remote, key)
|
||
|
|
}
|
||
|
|
|
||
|
|
bannerCallback := func(banner string) error {
|
||
|
|
sshInfo.Banner = banner
|
||
|
|
if info.BannerCallback != nil {
|
||
|
|
return info.BannerCallback(banner)
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
clientConfig := &ssh.ClientConfig{
|
||
|
|
User: info.User,
|
||
|
|
Auth: auth,
|
||
|
|
Timeout: info.Timeout,
|
||
|
|
HostKeyCallback: hostKeyCallback,
|
||
|
|
BannerCallback: bannerCallback,
|
||
|
|
}
|
||
|
|
if len(info.Ciphers) > 0 || len(info.MACs) > 0 || len(info.KeyExchanges) > 0 {
|
||
|
|
clientConfig.Config = ssh.Config{
|
||
|
|
Ciphers: info.Ciphers,
|
||
|
|
MACs: info.MACs,
|
||
|
|
KeyExchanges: info.KeyExchanges,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
targetAddr := joinHostPort(info.Addr, info.Port)
|
||
|
|
rawConn, upstream, err := dialTargetConn(loginCtx, info)
|
||
|
|
if err != nil {
|
||
|
|
return sshInfo, err
|
||
|
|
}
|
||
|
|
restoreDeadline := applyConnDeadline(rawConn, loginCtx, info.Timeout)
|
||
|
|
defer restoreDeadline()
|
||
|
|
|
||
|
|
clientConn, chans, reqs, err := ssh.NewClientConn(rawConn, targetAddr, clientConfig)
|
||
|
|
if err != nil {
|
||
|
|
_ = rawConn.Close()
|
||
|
|
if upstream != nil {
|
||
|
|
_ = upstream.Close()
|
||
|
|
}
|
||
|
|
return sshInfo, err
|
||
|
|
}
|
||
|
|
client := ssh.NewClient(clientConn, chans, reqs)
|
||
|
|
|
||
|
|
sshInfo.setTransport(client, upstream)
|
||
|
|
if sshInfo.PublicKey != nil {
|
||
|
|
sshInfo.PubkeyBase64 = base64.StdEncoding.EncodeToString(sshInfo.PublicKey.Marshal())
|
||
|
|
}
|
||
|
|
sshInfo.startAutoKeepAlive()
|
||
|
|
|
||
|
|
return sshInfo, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// contextWithLoginTimeout derives the context used for a login attempt.
//
// A nil ctx is replaced by context.Background(). When timeout is positive the
// result is that context wrapped with context.WithTimeout; otherwise the
// context is returned as is together with a no-op cancel function, so callers
// can always defer the returned CancelFunc.
func contextWithLoginTimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
	parent := ctx
	if parent == nil {
		parent = context.Background()
	}

	if timeout > 0 {
		return context.WithTimeout(parent, timeout)
	}

	noop := func() {}
	return parent, noop
}
|
||
|
|
|
||
|
|
func LoginSimple(host string, user string, passwd string, prikeyPath string, port int, timeout time.Duration) (*StarSSH, error) {
|
||
|
|
info := LoginInput{
|
||
|
|
Addr: host,
|
||
|
|
Port: port,
|
||
|
|
Timeout: timeout,
|
||
|
|
User: user,
|
||
|
|
HostKeyCallback: DefaultAllowHostKeyCallback,
|
||
|
|
}
|
||
|
|
|
||
|
|
if prikeyPath != "" {
|
||
|
|
prikey, err := os.ReadFile(prikeyPath)
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
info.Prikey = string(prikey)
|
||
|
|
if passwd != "" {
|
||
|
|
info.PrikeyPwd = passwd
|
||
|
|
}
|
||
|
|
} else {
|
||
|
|
info.Password = passwd
|
||
|
|
}
|
||
|
|
|
||
|
|
return Login(info)
|
||
|
|
}
|
||
|
|
|
||
|
|
func normalizeLoginInput(info LoginInput) LoginInput {
|
||
|
|
if info.Port <= 0 {
|
||
|
|
info.Port = defaultSSHPort
|
||
|
|
}
|
||
|
|
if info.Timeout <= 0 {
|
||
|
|
info.Timeout = defaultLoginTimeout
|
||
|
|
}
|
||
|
|
return info
|
||
|
|
}
|
||
|
|
|
||
|
|
func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) {
|
||
|
|
order, err := normalizeAuthOrder(info.AuthOrder)
|
||
|
|
if err != nil {
|
||
|
|
return nil, nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
auth := make([]ssh.AuthMethod, 0, len(order))
|
||
|
|
var agentErr error
|
||
|
|
var cleanupFuncs []func()
|
||
|
|
|
||
|
|
for _, methodKind := range order {
|
||
|
|
switch methodKind {
|
||
|
|
case AuthMethodPrivateKey:
|
||
|
|
method, err := buildPrivateKeyAuthMethod(info)
|
||
|
|
if err != nil {
|
||
|
|
return nil, nil, err
|
||
|
|
}
|
||
|
|
if method != nil {
|
||
|
|
auth = append(auth, method)
|
||
|
|
}
|
||
|
|
case AuthMethodPassword:
|
||
|
|
method := buildPasswordAuthMethod(info.Password, info.PasswordCallback)
|
||
|
|
if method != nil {
|
||
|
|
auth = append(auth, method)
|
||
|
|
}
|
||
|
|
case AuthMethodKeyboardInteractive:
|
||
|
|
method := buildKeyboardInteractiveAuthMethod(info.Password, info.PasswordCallback, info.KeyboardInteractiveCallback)
|
||
|
|
if method != nil {
|
||
|
|
auth = append(auth, method)
|
||
|
|
}
|
||
|
|
case AuthMethodSSHAgent:
|
||
|
|
if info.DisableSSHAgent {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
agentMethod, cleanup, err := buildSSHAgentAuthMethod(info.Timeout)
|
||
|
|
if err != nil {
|
||
|
|
agentErr = err
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
if agentMethod != nil {
|
||
|
|
auth = append(auth, agentMethod)
|
||
|
|
}
|
||
|
|
if cleanup != nil {
|
||
|
|
cleanupFuncs = append(cleanupFuncs, cleanup)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
if len(auth) == 0 {
|
||
|
|
if agentErr != nil {
|
||
|
|
return nil, nil, fmt.Errorf("no authentication method provided; ssh-agent unavailable: %w", agentErr)
|
||
|
|
}
|
||
|
|
return nil, nil, errors.New("no authentication method provided: password, private key, or ssh-agent is required")
|
||
|
|
}
|
||
|
|
|
||
|
|
return auth, composeCleanup(cleanupFuncs...), nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func normalizeAuthOrder(order []AuthMethodKind) ([]AuthMethodKind, error) {
|
||
|
|
if len(order) == 0 {
|
||
|
|
return append([]AuthMethodKind(nil), defaultAuthOrder...), nil
|
||
|
|
}
|
||
|
|
|
||
|
|
normalized := make([]AuthMethodKind, 0, len(order))
|
||
|
|
seen := make(map[AuthMethodKind]struct{}, len(order))
|
||
|
|
for _, raw := range order {
|
||
|
|
kind := AuthMethodKind(strings.ToLower(strings.TrimSpace(string(raw))))
|
||
|
|
if kind == "" {
|
||
|
|
return nil, errors.New("auth order contains an empty auth method")
|
||
|
|
}
|
||
|
|
if !isSupportedAuthMethodKind(kind) {
|
||
|
|
return nil, fmt.Errorf("unsupported auth method %q", raw)
|
||
|
|
}
|
||
|
|
if _, exists := seen[kind]; exists {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
seen[kind] = struct{}{}
|
||
|
|
normalized = append(normalized, kind)
|
||
|
|
}
|
||
|
|
|
||
|
|
if len(normalized) == 0 {
|
||
|
|
return nil, errors.New("auth order is empty")
|
||
|
|
}
|
||
|
|
return normalized, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func isSupportedAuthMethodKind(kind AuthMethodKind) bool {
|
||
|
|
switch kind {
|
||
|
|
case AuthMethodPrivateKey, AuthMethodPassword, AuthMethodKeyboardInteractive, AuthMethodSSHAgent:
|
||
|
|
return true
|
||
|
|
default:
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func buildPrivateKeyAuthMethod(info LoginInput) (ssh.AuthMethod, error) {
|
||
|
|
if strings.TrimSpace(info.Prikey) == "" {
|
||
|
|
return nil, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
pemBytes := []byte(info.Prikey)
|
||
|
|
if info.PrikeyPwd == "" {
|
||
|
|
signer, err := ssh.ParsePrivateKey(pemBytes)
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
return ssh.PublicKeys(signer), nil
|
||
|
|
}
|
||
|
|
|
||
|
|
signer, err := ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(info.PrikeyPwd))
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
return ssh.PublicKeys(signer), nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func buildPasswordAuthMethod(password string, callback func() (string, error)) ssh.AuthMethod {
|
||
|
|
if password != "" {
|
||
|
|
return ssh.Password(password)
|
||
|
|
}
|
||
|
|
if callback != nil {
|
||
|
|
return ssh.PasswordCallback(callback)
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func buildKeyboardInteractiveAuthMethod(
|
||
|
|
password string,
|
||
|
|
passwordCallback func() (string, error),
|
||
|
|
challenge ssh.KeyboardInteractiveChallenge,
|
||
|
|
) ssh.AuthMethod {
|
||
|
|
if challenge != nil {
|
||
|
|
return ssh.KeyboardInteractive(challenge)
|
||
|
|
}
|
||
|
|
if password == "" && passwordCallback == nil {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
keyboardInteractiveChallenge := func(user, instruction string, questions []string, echos []bool) ([]string, error) {
|
||
|
|
if len(questions) == 0 {
|
||
|
|
return []string{}, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
answer := password
|
||
|
|
if answer == "" {
|
||
|
|
var err error
|
||
|
|
answer, err = passwordCallback()
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
answers := make([]string, len(questions))
|
||
|
|
for i := range questions {
|
||
|
|
answers[i] = answer
|
||
|
|
}
|
||
|
|
return answers, nil
|
||
|
|
}
|
||
|
|
return ssh.KeyboardInteractive(keyboardInteractiveChallenge)
|
||
|
|
}
|
||
|
|
|
||
|
|
func buildSSHAgentAuthMethod(timeout time.Duration) (ssh.AuthMethod, func(), error) {
|
||
|
|
conn, err := dialSSHAgent(timeout)
|
||
|
|
if err != nil {
|
||
|
|
if errors.Is(err, errSSHAgentUnavailable) {
|
||
|
|
return nil, nil, nil
|
||
|
|
}
|
||
|
|
return nil, nil, err
|
||
|
|
}
|
||
|
|
if conn == nil {
|
||
|
|
return nil, nil, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
signers, err := agent.NewClient(conn).Signers()
|
||
|
|
if err != nil {
|
||
|
|
_ = conn.Close()
|
||
|
|
return nil, nil, err
|
||
|
|
}
|
||
|
|
if len(signers) == 0 {
|
||
|
|
_ = conn.Close()
|
||
|
|
return nil, nil, errors.New("ssh-agent has no loaded keys")
|
||
|
|
}
|
||
|
|
|
||
|
|
return ssh.PublicKeys(signers...), func() {
|
||
|
|
_ = conn.Close()
|
||
|
|
}, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// composeCleanup merges several cleanup functions into one that runs them in
// reverse order (last registered, first executed), skipping nil entries.
// It returns nil when called without any functions.
func composeCleanup(funcs ...func()) func() {
	if len(funcs) == 0 {
		return nil
	}

	return func() {
		for offset := range funcs {
			cleanup := funcs[len(funcs)-1-offset]
			if cleanup == nil {
				continue
			}
			cleanup()
		}
	}
}
|