package starssh

import (
	"context"
	"encoding/base64"
	"errors"
	"net"
	"os"
	"time"

	"golang.org/x/crypto/ssh"
)

// ErrHostKeyCallbackRequired is returned by the login functions when
// LoginInput.HostKeyCallback is nil.
var ErrHostKeyCallbackRequired = errors.New("host key callback is required; use DefaultAllowHostKeyCallback to explicitly allow any host key")

// DefaultAllowHostKeyCallback is a host key callback that accepts every host
// key: it always returns nil and performs no verification.
func DefaultAllowHostKeyCallback(hostname string, remote net.Addr, key ssh.PublicKey) error {
	return nil
}

// LoginContext logs in using info, bounded by ctx.
func LoginContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
	return loginWithContext(ctx, info)
}

// Login is LoginContext with context.Background().
func Login(info LoginInput) (*StarSSH, error) {
	return LoginContext(context.Background(), info)
}

// loginWithContext normalizes info, requires a host key callback, derives the
// login context from the effective login timeout and then performs the login.
//
// When shouldRetrySSHAgentAuth reports true, the login is repeated for as long
// as an attempt fails with errRetrySSHAgentAuth and the login context is still
// alive; any other outcome is returned as-is.
func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
	info = normalizeLoginInput(info)
	if info.HostKeyCallback == nil {
		return nil, ErrHostKeyCallbackRequired
	}

	authTimeout := effectiveLoginTimeout(info)
	loginCtx, cancel := contextWithLoginTimeout(ctx, authTimeout)
	defer cancel()

	order, orderErr := normalizeAuthOrder(info.AuthOrder)
	if orderErr != nil {
		return nil, orderErr
	}

	// Common case: a single attempt without agent retry bookkeeping.
	if !shouldRetrySSHAgentAuth(info, order) {
		return loginOnceWithContext(loginCtx, info, authTimeout, nil)
	}

	attempt := newSSHAgentAuthAttempt()
	for {
		attempt.begin()
		session, loginErr := loginOnceWithContext(loginCtx, info, authTimeout, attempt)

		// Only a retry sentinel on a still-live context starts another round.
		retry := loginErr != nil &&
			errors.Is(loginErr, errRetrySSHAgentAuth) &&
			loginCtx.Err() == nil
		if !retry {
			return session, loginErr
		}
	}
}

// loginOnceWithContext performs a single login attempt: it builds the auth
// methods, dials the target, runs the SSH handshake over the dialed
// connection and, on success, installs the transport and starts keep-alive.
//
// If building the auth methods fails, the result is nil. If dialing or the
// handshake fails, the partially filled *StarSSH (holding whatever the host
// key and banner callbacks recorded) is returned together with the error.
func loginOnceWithContext(ctx context.Context, info LoginInput, authTimeout time.Duration, agentAttempt *sshAgentAuthAttempt) (*StarSSH, error) {
	session := &StarSSH{LoginInfo: info}

	methods, releaseAuth, authErr := buildAuthMethodsWithAgentAttempt(info, agentAttempt)
	if authErr != nil {
		return nil, authErr
	}
	if releaseAuth != nil {
		defer releaseAuth()
	}

	// Record what the server presented before delegating to the caller's
	// callbacks, so the data is available even when the login fails.
	recordHostKey := func(hostname string, remote net.Addr, key ssh.PublicKey) error {
		session.Hostname = hostname
		session.RemoteAddr = remote
		session.PublicKey = key
		return info.HostKeyCallback(hostname, remote, key)
	}
	recordBanner := func(banner string) error {
		session.Banner = banner
		if info.BannerCallback == nil {
			return nil
		}
		return info.BannerCallback(banner)
	}

	cfg := &ssh.ClientConfig{
		User:            info.User,
		Auth:            methods,
		Timeout:         authTimeout,
		HostKeyCallback: recordHostKey,
		BannerCallback:  recordBanner,
	}

	// Override the algorithm configuration only when the caller supplied
	// at least one list.
	customAlgorithms := len(info.Ciphers) > 0 ||
		len(info.MACs) > 0 ||
		len(info.KeyExchanges) > 0
	if customAlgorithms {
		cfg.Config = ssh.Config{
			Ciphers:      info.Ciphers,
			MACs:         info.MACs,
			KeyExchanges: info.KeyExchanges,
		}
	}

	targetAddr := joinHostPort(info.Addr, info.Port)

	rawConn, upstream, dialErr := dialTargetConn(ctx, info)
	if dialErr != nil {
		return session, dialErr
	}

	restoreDeadline := applyConnDeadline(rawConn, ctx, authTimeout)
	defer restoreDeadline()

	clientConn, chans, reqs, handshakeErr := ssh.NewClientConn(rawConn, targetAddr, cfg)
	if handshakeErr != nil {
		// Best-effort cleanup; the handshake error is the one reported.
		_ = rawConn.Close()
		if upstream != nil {
			_ = upstream.Close()
		}
		return session, handshakeErr
	}

	session.setTransport(ssh.NewClient(clientConn, chans, reqs), upstream)
	if session.PublicKey != nil {
		session.PubkeyBase64 = base64.StdEncoding.EncodeToString(session.PublicKey.Marshal())
	}
	session.startAutoKeepAlive()
	return session, nil
}

// contextWithLoginTimeout returns a context bounded by timeout together with
// its cancel function. A nil ctx is replaced by context.Background(). When
// timeout is zero or negative, no deadline is added and the returned cancel
// function is a no-op.
func contextWithLoginTimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
	parent := ctx
	if parent == nil {
		parent = context.Background()
	}
	if timeout > 0 {
		return context.WithTimeout(parent, timeout)
	}
	return parent, func() {}
}

// LoginSimple logs in with either a password or a private key file.
//
// When prikeyPath is empty, passwd is used as the login password. Otherwise
// the file at prikeyPath is read as the private key and a non-empty passwd is
// used as the key's passphrase. timeout is applied to both Timeout and
// DialTimeout. The host key is not verified: DefaultAllowHostKeyCallback is
// installed as the host key callback.
func LoginSimple(host string, user string, passwd string, prikeyPath string, port int, timeout time.Duration) (*StarSSH, error) {
	info := LoginInput{
		Addr:            host,
		Port:            port,
		Timeout:         timeout,
		DialTimeout:     timeout,
		User:            user,
		HostKeyCallback: DefaultAllowHostKeyCallback,
	}

	if prikeyPath == "" {
		info.Password = passwd
		return Login(info)
	}

	keyData, readErr := os.ReadFile(prikeyPath)
	if readErr != nil {
		return nil, readErr
	}
	info.Prikey = string(keyData)
	if passwd != "" {
		info.PrikeyPwd = passwd
	}
	return Login(info)
}

// normalizeLoginInput returns info with a non-positive Port replaced by
// defaultSSHPort.
func normalizeLoginInput(info LoginInput) LoginInput {
	if info.Port > 0 {
		return info
	}
	info.Port = defaultSSHPort
	return info
}

// effectiveLoginTimeout returns info.Timeout, clamping non-positive values
// to zero.
func effectiveLoginTimeout(info LoginInput) time.Duration {
	if info.Timeout > 0 {
		return info.Timeout
	}
	return 0
}

// effectiveDialTimeout picks the dial timeout with this precedence: a
// negative DialTimeout yields 0 (presumably "no dial timeout"; confirm
// against the dialer), a positive DialTimeout is used as-is, then a positive
// Timeout, and finally defaultLoginTimeout.
func effectiveDialTimeout(info LoginInput) time.Duration {
	if info.DialTimeout < 0 {
		return 0
	}
	if info.DialTimeout > 0 {
		return info.DialTimeout
	}
	if info.Timeout > 0 {
		return info.Timeout
	}
	return defaultLoginTimeout
}