6 Commits

Author SHA1 Message Date
b612 672a111ec1 feat(sftp): 增加可选的并发传输策略
- 增加 SFTP client 级配置,支持 packet size、单文件并发请求数、并发读和并发写
- 将吞吐优化限制在 StarSSH 托管的 SFTP client 路径中
- 上传和下载在显式启用时使用并发快路径,同时保留原子传输生命周期
- 避免快路径失败前提前上报 100% 进度
- 补充安全校验、托管 client 配置、进度、取消和下载对齐的回归测试
2026-06-22 03:26:47 +08:00
b612 0c23e7d4bf feat: 增强 ssh-agent 认证与转发可靠性
- 拆分 ssh-agent 认证、连接与 endpoint 解析逻辑
- 新增 IdentityAgent、SSHAgentTimeout、SSHAgentForwardTimeout 和调试事件
- 为 agent list/sign 操作增加独立 deadline,避免硬件 agent 卡死登录
- 支持 agent signer 失败后跳过坏 key 并重试后续 key
- 优先处理 RSA-SHA2 签名,兼容现代 OpenSSH 认证要求
- 增强 agent forwarding 的探测、通道空闲超时和关闭清理
- 补充 Windows OpenSSH pipe 与 GPG S.gpg-agent.ssh socket 文件支持
- 增加相关回归测试和 Windows 编译验证覆盖
2026-05-27 13:10:35 +08:00
b612 ad7c8b0587 fix: 重构 ssh agent forwarding 转发代理并修复资源残留
- 将 agent forwarding 从长持有本地 agent 连接改为按 uth-agent@openssh.com channel 动态建桥
- 新增 ssh-agent 可用性探测流程,区分 forwarding 探测与实际转发注册
- 重构 forwarding 注册接口,按连接超时创建本地 agent bridge,不再复用单个长期占用的 agent 连接
- 新增 sshAgentForwardProxy 与 sshAgentForwardBridge,显式管理活跃 agent bridge 生命周期
- 在远端 channel 单侧 EOF、bridge 关闭和 proxy.Close() 路径上主动关闭本地 agent 连接与 SSH channel,避免 goroutine 和 agent 句柄残留
- 保留现有 denied / unavailable / close-race 语义,并继续保证自动 forwarding 的 best-effort 行为
- 扩充 agent forwarding 回归测试,覆盖单次启用、禁用、denied、unavailable、close race、单侧 EOF 释放以及 proxy Close 主动回收活跃 bridge 等关键场景
2026-04-27 00:06:32 +08:00
b612 1625997d8f fix: 拆分 starssh 的拨号超时与认证超时语义
- 为 LoginInput 新增 DialTimeout,明确区分【TCP/proxy/ssh-agent 拨号超时】和【SSH 握手/认证超时】
- 将 Timeout 收口为握手/认证阶段超时,0 表示不限制,不再在登录入口自动回填默认值
- 新增 effectiveLoginTimeout/effectiveDialTimeout,统一超时决策逻辑
- 调整 login 流程,仅对 login context、ssh.ClientConfig 和握手阶段连接 deadline 使用认证超时
- 调整 transport 拨号链路,默认 TCP dial、proxy dial 与 ssh-agent 建连统一改用 DialTimeout
- 修正 agent forwarding 初始化仍错误复用 LoginInfo.Timeout 的问题
- 保持 LoginSimple 的直观行为:传入 timeout 时同时映射到 Timeout 和 DialTimeout
- 新增 login_timeout_test,覆盖零值不回填、DialTimeout 优先级,以及 ssh-agent 认证路径使用拨号超时的回归测试
2026-04-26 23:29:36 +08:00
b612 b29246a9c4 feat: 增强 starssh 的 agent forwarding 与 tcp/unix 转发能力
- 为 LoginInput 增加 ForwardSSHAgent 配置,并在 Exec/PTTY 会话创建时按需自动请求 agent forwarding
- 新增 agent_forward 运行时,封装本地 ssh-agent 建连、转发注册、显式请求与 unavailable/denied 语义
- 自动 agent forwarding 改为 best-effort:本地 agent 不可用、转发被拒绝或初始化失败时不再打断会话创建
- 为 StarSSH 增加 closing 状态与 agent forwarder 生命周期回收,避免 Close 与会话创建并发时泄漏资源
- 扩展 ForwardRequest 为带网络归一化的转发模型,支持 tcp/tcp4/tcp6/unix 端点组合
- 新增本地/远端 tcp<->unix、unix<->unix 及 detached helper,补齐 streamlocal 场景下的常用 API
- 将显式网络地址编码收口为 tcp4://、tcp6://、unix://,消除 tcp:22 一类值的解析歧义
- 为本地 unix listener 增加 stale socket 探测、复用与关闭清理,避免遗留 socket 导致重启失败
- 补充 agent forwarding、关闭竞态、remote unix forward、local unix forward、stale socket 复用与端点解析等回归测试
2026-04-26 20:27:10 +08:00
b612 f20eb653ae refactor: 重构 starssh 核心运行时并补强 ssh/exec/terminal/sftp 能力
- 拆分原有单体 ssh.go,按职责重组为 types、utils、transport、login、keepalive、session、exec、pool、shell、terminal、forward、hostkey、state 等模块,并补充平台相关实现
  - 重做登录与连接运行时,补齐基于 context 的建连、jump/proxy 链路、可配置认证顺序,以及 Unix/Windows 下的 ssh-agent 支持
  - 新增正式非交互执行模型 ExecRequest/ExecResult,支持流式输出、溢出统计、超时控制,以及 posix/powershell/cmd/raw 多方言执行
  - 保留旧 shell 风格兼容接口,同时让路径/用户探测等 helper 具备跨 shell fallback,避免 Windows 目标继续硬依赖 POSIX 命令
  - 新增 TerminalSession 作为原始交互终端基座,提供 IO attach、resize、signal/control、退出状态与关闭原因管理
  - 重构端口转发语义,默认复用当前 SSH 连接,并显式提供 detached 的本地/动态转发模式承载隔离场景
  - 梳理 keepalive 与取消语义,区分仅取消本次操作和关闭整条连接,并统一连接状态与传输关闭路径
  - 围绕新的 session/连接生命周期重做执行池与运行时支撑
  - 大幅增强 SFTP 传输链路,补齐更安全的原子替换、校验、进度回调、重试隔离、可复用 client 生命周期与失败语义
  - 新增取消语义、keepalive、SFTP、forward、terminal input 等关键回归测试,提升核心链路稳定性
2026-04-26 10:45:39 +08:00
37 changed files with 13281 additions and 771 deletions
+9
View File
@@ -0,0 +1,9 @@
.sentrux/
agent_readme.md
target.md
.gocache/
.tmp_*/
.codex/
.idea/
agents.md
.codex
+201
View File
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
+454
View File
@@ -0,0 +1,454 @@
package starssh
import (
"errors"
"fmt"
"io"
"net"
"strings"
"sync"
"time"
"golang.org/x/crypto/ssh"
sshagent "golang.org/x/crypto/ssh/agent"
)
var requestSSHAgentForwarding = func(session *ssh.Session) error {
return sshagent.RequestAgentForwarding(session)
}
const sshAgentChannelType = "auth-agent@openssh.com"
var routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
return startSSHAgentForwardProxy(client, timeouts)
}
var probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
conn, _, err := dialSSHAgentWithDebug("forward-probe", timeouts)
if err != nil {
return wrapSSHAgentForwardingUnavailable(err)
}
if conn == nil {
return wrapSSHAgentForwardingUnavailable(errors.New("empty agent connection"))
}
return conn.Close()
}
var errSSHAgentForwardingDenied = errors.New("ssh-agent forwarding request denied")
var errSSHAgentForwardingUnavailable = errors.New("ssh-agent forwarding unavailable")
type sshAgentForwardProxy struct {
stopOnce sync.Once
stopCh chan struct{}
activeMu sync.Mutex
active map[*sshAgentForwardBridge]struct{}
}
func (p *sshAgentForwardProxy) Close() error {
if p == nil {
return nil
}
p.stopOnce.Do(func() {
close(p.stopCh)
})
p.closeActive()
return nil
}
type sshAgentForwardBridge struct {
proxy *sshAgentForwardProxy
channel ssh.Channel
conn net.Conn
idleTimeout time.Duration
closeOnce sync.Once
signalOnce sync.Once
done chan struct{}
activity chan struct{}
}
func (s *StarSSH) RequestAgentForwarding(session *ssh.Session) error {
if s == nil {
return errors.New("ssh client is nil")
}
if session == nil {
return errors.New("ssh session is nil")
}
if err := s.ensureAgentForwarding(); err != nil {
return err
}
if err := requestSSHAgentForwarding(session); err != nil {
if isSSHAgentForwardingDeniedError(err) {
return fmt.Errorf("%w: %v", errSSHAgentForwardingDenied, err)
}
return err
}
return nil
}
func (s *StarSSH) maybeRequestAgentForwarding(session *ssh.Session) error {
if s == nil || !s.LoginInfo.ForwardSSHAgent {
return nil
}
err := s.RequestAgentForwarding(session)
if isSSHAgentForwardingDeniedError(err) || isSSHAgentForwardingUnavailableError(err) {
return nil
}
return err
}
func (s *StarSSH) ensureAgentForwarding() error {
if s == nil {
return errors.New("ssh client is nil")
}
s.agentForwardMu.Lock()
defer s.agentForwardMu.Unlock()
if s.agentForwarder != nil {
return nil
}
client, err := s.requireSSHClient()
if err != nil {
return err
}
timeouts := effectiveSSHAgentTimeouts(s.LoginInfo)
if err := probeSSHAgentForwarding(timeouts); err != nil {
return wrapSSHAgentForwardingUnavailable(err)
}
if s.closing.Load() {
return errSSHClientClosing
}
closer, err := routeSSHAgentForwarding(client, timeouts)
if err != nil {
return err
}
if !s.canAttachAgentForwarder(client) {
if closer != nil {
_ = closer.Close()
}
return errSSHClientClosing
}
s.agentForwarder = closer
return nil
}
func (s *StarSSH) takeAgentForwarder() io.Closer {
if s == nil {
return nil
}
s.agentForwardMu.Lock()
defer s.agentForwardMu.Unlock()
closer := s.agentForwarder
s.agentForwarder = nil
return closer
}
func isSSHAgentForwardingDeniedError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, errSSHAgentForwardingDenied) {
return true
}
message := strings.ToLower(err.Error())
return strings.Contains(message, "forwarding request denied") ||
strings.Contains(message, "agent forwarding disabled")
}
func isSSHAgentForwardingUnavailableError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, errSSHAgentForwardingUnavailable) {
return true
}
message := strings.ToLower(err.Error())
return strings.Contains(message, "ssh-agent forwarding unavailable") ||
strings.Contains(message, "ssh-agent unavailable")
}
func wrapSSHAgentForwardingUnavailable(err error) error {
if err == nil {
return nil
}
if errors.Is(err, errSSHAgentForwardingUnavailable) {
return err
}
if errors.Is(err, errSSHAgentUnavailable) {
return fmt.Errorf("%w: %w", errSSHAgentForwardingUnavailable, err)
}
return fmt.Errorf("%w: %v", errSSHAgentForwardingUnavailable, err)
}
func startSSHAgentForwardProxy(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
if client == nil {
return nil, errors.New("ssh client is nil")
}
channels := client.HandleChannelOpen(sshAgentChannelType)
if channels == nil {
return nil, errors.New("agent: already have handler for " + sshAgentChannelType)
}
proxy := &sshAgentForwardProxy{
stopCh: make(chan struct{}),
active: make(map[*sshAgentForwardBridge]struct{}),
}
go func() {
for {
select {
case <-proxy.stopCh:
return
case ch, ok := <-channels:
if !ok {
return
}
go handleSSHAgentForwardChannel(proxy, ch, timeouts)
}
}
}()
return proxy, nil
}
func handleSSHAgentForwardChannel(proxy *sshAgentForwardProxy, ch ssh.NewChannel, timeouts sshAgentTimeouts) {
if ch == nil {
return
}
conn, _, err := dialSSHAgentWithDebug("forward-channel", timeouts)
if err != nil {
_ = ch.Reject(ssh.ConnectionFailed, err.Error())
return
}
if conn == nil {
_ = ch.Reject(ssh.ConnectionFailed, "ssh-agent connection unavailable")
return
}
channel, reqs, err := ch.Accept()
if err != nil {
_ = conn.Close()
return
}
go ssh.DiscardRequests(reqs)
bridge := &sshAgentForwardBridge{
proxy: proxy,
channel: channel,
conn: conn,
idleTimeout: timeouts.Forward,
}
if !proxy.registerBridge(bridge) {
bridge.close()
return
}
go bridge.run()
}
func proxySSHAgentChannel(channel ssh.Channel, conn net.Conn) {
bridge := &sshAgentForwardBridge{
channel: channel,
conn: conn,
}
bridge.run()
}
func (b *sshAgentForwardBridge) run() {
if b == nil {
return
}
b.ensureSignals()
stopWatchdog := b.startIdleWatchdog()
defer stopWatchdog()
defer b.unregister()
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
_, _ = io.Copy(
sshAgentForwardActivityWriter{Writer: b.channel, touch: b.touch},
sshAgentForwardActivityReader{Reader: b.conn, touch: b.touch},
)
b.close()
}()
go func() {
defer wg.Done()
_, _ = io.Copy(
sshAgentForwardActivityWriter{Writer: b.conn, touch: b.touch},
sshAgentForwardActivityReader{Reader: b.channel, touch: b.touch},
)
b.close()
}()
wg.Wait()
}
func (b *sshAgentForwardBridge) close() {
if b == nil {
return
}
b.closeOnce.Do(func() {
b.ensureSignals()
close(b.done)
closeWriter(b.channel)
closeWriter(b.conn)
if b.channel != nil {
_ = b.channel.Close()
}
if b.conn != nil {
_ = b.conn.Close()
}
})
}
func (b *sshAgentForwardBridge) ensureSignals() {
if b == nil {
return
}
b.signalOnce.Do(func() {
b.done = make(chan struct{})
b.activity = make(chan struct{}, 1)
})
}
func (b *sshAgentForwardBridge) startIdleWatchdog() func() {
if b == nil || b.idleTimeout <= 0 {
return func() {}
}
b.ensureSignals()
timer := time.NewTimer(b.idleTimeout)
stopped := make(chan struct{})
go func() {
defer timer.Stop()
for {
select {
case <-timer.C:
b.close()
return
case <-b.activity:
resetTimer(timer, b.idleTimeout)
case <-b.done:
return
case <-stopped:
return
}
}
}()
return func() {
close(stopped)
}
}
func (b *sshAgentForwardBridge) touch() {
if b == nil || b.idleTimeout <= 0 || b.activity == nil {
return
}
select {
case b.activity <- struct{}{}:
default:
}
}
type sshAgentForwardActivityReader struct {
io.Reader
touch func()
}
func (r sshAgentForwardActivityReader) Read(p []byte) (int, error) {
n, err := r.Reader.Read(p)
if n > 0 && r.touch != nil {
r.touch()
}
return n, err
}
type sshAgentForwardActivityWriter struct {
io.Writer
touch func()
}
func (w sshAgentForwardActivityWriter) Write(p []byte) (int, error) {
n, err := w.Writer.Write(p)
if n > 0 && w.touch != nil {
w.touch()
}
return n, err
}
func resetTimer(timer *time.Timer, timeout time.Duration) {
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
timer.Reset(timeout)
}
func (b *sshAgentForwardBridge) unregister() {
if b == nil || b.proxy == nil {
return
}
b.proxy.unregisterBridge(b)
}
func (p *sshAgentForwardProxy) registerBridge(bridge *sshAgentForwardBridge) bool {
if p == nil || bridge == nil {
return false
}
p.activeMu.Lock()
defer p.activeMu.Unlock()
select {
case <-p.stopCh:
return false
default:
}
if p.active == nil {
p.active = make(map[*sshAgentForwardBridge]struct{})
}
p.active[bridge] = struct{}{}
return true
}
func (p *sshAgentForwardProxy) unregisterBridge(bridge *sshAgentForwardBridge) {
if p == nil || bridge == nil {
return
}
p.activeMu.Lock()
defer p.activeMu.Unlock()
delete(p.active, bridge)
}
func (p *sshAgentForwardProxy) closeActive() {
if p == nil {
return
}
p.activeMu.Lock()
active := make([]*sshAgentForwardBridge, 0, len(p.active))
for bridge := range p.active {
active = append(active, bridge)
}
p.active = make(map[*sshAgentForwardBridge]struct{})
p.activeMu.Unlock()
for _, bridge := range active {
bridge.close()
}
}
func closeWriter(value any) {
type closeWriter interface {
CloseWrite() error
}
if cw, ok := value.(closeWriter); ok {
_ = cw.CloseWrite()
}
}
+665
View File
@@ -0,0 +1,665 @@
package starssh
import (
"bytes"
"errors"
"io"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"golang.org/x/crypto/ssh"
)
type testCloser struct {
closed atomic.Int32
}
func (c *testCloser) Close() error {
c.closed.Add(1)
return nil
}
type trackedConn struct {
net.Conn
closed atomic.Int32
}
func (c *trackedConn) Close() error {
c.closed.Add(1)
if c.Conn == nil {
return nil
}
return c.Conn.Close()
}
type testSSHChannel struct {
readFunc func([]byte) (int, error)
stderr bytes.Buffer
closed atomic.Int32
closeOnce sync.Once
closeCh chan struct{}
}
type testNewChannel struct {
channel ssh.Channel
accepted atomic.Bool
rejected atomic.Bool
}
func (c *testNewChannel) Accept() (ssh.Channel, <-chan *ssh.Request, error) {
c.accepted.Store(true)
requests := make(chan *ssh.Request)
close(requests)
return c.channel, requests, nil
}
func (c *testNewChannel) Reject(reason ssh.RejectionReason, message string) error {
c.rejected.Store(true)
return nil
}
func (c *testNewChannel) ChannelType() string {
return sshAgentChannelType
}
func (c *testNewChannel) ExtraData() []byte {
return nil
}
func newTestSSHChannel(readFunc func([]byte) (int, error)) *testSSHChannel {
return &testSSHChannel{
readFunc: readFunc,
closeCh: make(chan struct{}),
}
}
func newBlockingTestSSHChannel() *testSSHChannel {
ch := newTestSSHChannel(nil)
ch.readFunc = func(p []byte) (int, error) {
<-ch.closeCh
return 0, io.EOF
}
return ch
}
func (c *testSSHChannel) Read(p []byte) (int, error) {
if c == nil {
return 0, io.EOF
}
if c.readFunc != nil {
return c.readFunc(p)
}
return 0, io.EOF
}
func (c *testSSHChannel) Write(p []byte) (int, error) {
return len(p), nil
}
func (c *testSSHChannel) Close() error {
if c == nil {
return nil
}
c.closeOnce.Do(func() {
c.closed.Add(1)
close(c.closeCh)
})
return nil
}
func (c *testSSHChannel) CloseWrite() error {
return nil
}
func (c *testSSHChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
return false, nil
}
func (c *testSSHChannel) Stderr() io.ReadWriter {
return &c.stderr
}
func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
oldNewSSHSession := newSSHSession
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
oldCloseSSHClient := closeSSHClient
t.Cleanup(func() {
newSSHSession = oldNewSSHSession
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
closeSSHClient = oldCloseSSHClient
})
baseClient := &ssh.Client{}
star := &StarSSH{
LoginInfo: LoginInput{
ForwardSSHAgent: true,
Timeout: time.Second,
SSHAgentTimeout: 3 * time.Second,
SSHAgentForwardTimeout: 4 * time.Second,
},
}
star.setTransport(baseClient, nil)
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
if client != baseClient {
t.Fatalf("unexpected ssh client %p", client)
}
return &ssh.Session{}, nil
}
var probeCalls atomic.Int32
closer := &testCloser{}
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
probeCalls.Add(1)
if timeouts.Dial != time.Second {
t.Fatalf("unexpected forwarding dial timeout: %v", timeouts.Dial)
}
if timeouts.Operation != 3*time.Second {
t.Fatalf("unexpected forwarding operation timeout: %v", timeouts.Operation)
}
if timeouts.Forward != 4*time.Second {
t.Fatalf("unexpected forwarding idle timeout: %v", timeouts.Forward)
}
return nil
}
var routeCalls atomic.Int32
routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
routeCalls.Add(1)
if client != baseClient {
t.Fatalf("unexpected routed client %p", client)
}
if timeouts.Dial != time.Second {
t.Fatalf("unexpected routed dial timeout: %v", timeouts.Dial)
}
if timeouts.Operation != 3*time.Second {
t.Fatalf("unexpected routed operation timeout: %v", timeouts.Operation)
}
if timeouts.Forward != 4*time.Second {
t.Fatalf("unexpected routed idle timeout: %v", timeouts.Forward)
}
return closer, nil
}
var requestCalls atomic.Int32
requestSSHAgentForwarding = func(session *ssh.Session) error {
requestCalls.Add(1)
if session == nil {
t.Fatal("expected non-nil ssh session")
}
return nil
}
if _, err := star.NewExecSession(); err != nil {
t.Fatalf("first exec session: %v", err)
}
if _, err := star.NewExecSession(); err != nil {
t.Fatalf("second exec session: %v", err)
}
if probeCalls.Load() != 1 {
t.Fatalf("expected one agent probe, got %d", probeCalls.Load())
}
if routeCalls.Load() != 1 {
t.Fatalf("expected one agent route registration, got %d", routeCalls.Load())
}
if requestCalls.Load() != 2 {
t.Fatalf("expected agent forwarding request on each session, got %d", requestCalls.Load())
}
closeSSHClient = func(client sshClientRequester) error { return nil }
if err := star.Close(); err != nil {
t.Fatalf("close starssh: %v", err)
}
if closer.closed.Load() != 1 {
t.Fatalf("expected forwarded agent closer to run once, got %d", closer.closed.Load())
}
}
func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) {
oldNewSSHSession := newSSHSession
oldRequestSessionPTY := requestSessionPTY
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
t.Cleanup(func() {
newSSHSession = oldNewSSHSession
requestSessionPTY = oldRequestSessionPTY
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
})
star := &StarSSH{
LoginInfo: LoginInput{
ForwardSSHAgent: true,
},
}
star.setTransport(&ssh.Client{}, nil)
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
return &ssh.Session{}, nil
}
var ptyCalls atomic.Int32
requestSessionPTY = func(session *ssh.Session, config TerminalConfig) error {
ptyCalls.Add(1)
return nil
}
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
return nil
}
routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
return &testCloser{}, nil
}
var requestCalls atomic.Int32
requestSSHAgentForwarding = func(session *ssh.Session) error {
requestCalls.Add(1)
return nil
}
if _, err := star.NewPTYSession(nil); err != nil {
t.Fatalf("new pty session: %v", err)
}
if ptyCalls.Load() != 1 {
t.Fatalf("expected one PTY request, got %d", ptyCalls.Load())
}
if requestCalls.Load() != 1 {
t.Fatalf("expected one agent forwarding request, got %d", requestCalls.Load())
}
}
func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) {
oldNewSSHSession := newSSHSession
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
t.Cleanup(func() {
newSSHSession = oldNewSSHSession
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
})
star := &StarSSH{}
star.setTransport(&ssh.Client{}, nil)
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
return &ssh.Session{}, nil
}
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
t.Fatal("agent forwarding probe should not run when disabled")
return nil
}
requestSSHAgentForwarding = func(session *ssh.Session) error {
t.Fatal("agent forwarding should not be requested when disabled")
return nil
}
if _, err := star.NewExecSession(); err != nil {
t.Fatalf("new exec session without forwarding: %v", err)
}
}
func TestRequestAgentForwardingReturnsUnavailableError(t *testing.T) {
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
t.Cleanup(func() {
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
})
star := &StarSSH{}
star.setTransport(&ssh.Client{}, nil)
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
}
requestSSHAgentForwarding = func(session *ssh.Session) error {
t.Fatal("session request should not run when agent forwarder init fails")
return nil
}
err := star.RequestAgentForwarding(&ssh.Session{})
if err == nil {
t.Fatal("expected agent forwarding init error")
}
}
func TestRequestAgentForwardingWrapsSetupErrorAsUnavailable(t *testing.T) {
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
t.Cleanup(func() {
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
})
star := &StarSSH{}
star.setTransport(&ssh.Client{}, nil)
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
return errors.New("dial unix /tmp/ssh-broken.sock: connect: permission denied")
}
err := star.RequestAgentForwarding(&ssh.Session{})
if !isSSHAgentForwardingUnavailableError(err) {
t.Fatalf("expected unavailable error, got %v", err)
}
}
func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) {
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
t.Cleanup(func() {
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
})
star := &StarSSH{}
star.setTransport(&ssh.Client{}, nil)
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
return nil
}
routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
return &testCloser{}, nil
}
requestSSHAgentForwarding = func(session *ssh.Session) error {
return errors.New("forwarding request denied")
}
err := star.RequestAgentForwarding(&ssh.Session{})
if !isSSHAgentForwardingDeniedError(err) {
t.Fatalf("expected forwarding denied error, got %v", err)
}
}
func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) {
oldNewSSHSession := newSSHSession
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
t.Cleanup(func() {
newSSHSession = oldNewSSHSession
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
})
star := &StarSSH{
LoginInfo: LoginInput{
ForwardSSHAgent: true,
},
}
star.setTransport(&ssh.Client{}, nil)
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
return &ssh.Session{}, nil
}
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
return nil
}
routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
return &testCloser{}, nil
}
requestSSHAgentForwarding = func(session *ssh.Session) error {
return errors.New("forwarding request denied")
}
if _, err := star.NewExecSession(); err != nil {
t.Fatalf("new exec session should ignore denied agent forwarding: %v", err)
}
}
func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) {
oldNewSSHSession := newSSHSession
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
t.Cleanup(func() {
newSSHSession = oldNewSSHSession
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
})
star := &StarSSH{
LoginInfo: LoginInput{
ForwardSSHAgent: true,
},
}
star.setTransport(&ssh.Client{}, nil)
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
return &ssh.Session{}, nil
}
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
}
if _, err := star.NewExecSession(); err != nil {
t.Fatalf("new exec session should ignore unavailable agent forwarding: %v", err)
}
}
func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) {
oldNewSSHSession := newSSHSession
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
t.Cleanup(func() {
newSSHSession = oldNewSSHSession
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
})
star := &StarSSH{
LoginInfo: LoginInput{
ForwardSSHAgent: true,
},
}
star.setTransport(&ssh.Client{}, nil)
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
return &ssh.Session{}, nil
}
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
return errors.New("dial unix /tmp/ssh-broken.sock: connect: connection refused")
}
if _, err := star.NewExecSession(); err != nil {
t.Fatalf("new exec session should ignore agent setup error: %v", err)
}
}
func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) {
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
oldCloseSSHClient := closeSSHClient
t.Cleanup(func() {
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
closeSSHClient = oldCloseSSHClient
})
star := &StarSSH{
LoginInfo: LoginInput{
ForwardSSHAgent: true,
},
}
star.setTransport(&ssh.Client{}, nil)
started := make(chan struct{})
release := make(chan struct{})
closer := &testCloser{}
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
return nil
}
routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
close(started)
<-release
return closer, nil
}
closeSSHClient = func(client sshClientRequester) error { return nil }
errCh := make(chan error, 1)
go func() {
errCh <- star.ensureAgentForwarding()
}()
<-started
closeDone := make(chan struct{})
go func() {
_ = star.Close()
close(closeDone)
}()
deadline := time.Now().Add(time.Second)
for !star.closing.Load() {
if time.Now().After(deadline) {
t.Fatal("close did not enter closing state in time")
}
time.Sleep(time.Millisecond)
}
close(release)
err := <-errCh
if !errors.Is(err, errSSHClientClosing) {
t.Fatalf("expected closing error, got %v", err)
}
<-closeDone
if closer.closed.Load() != 1 {
t.Fatalf("expected new forwarder closer to be closed once, got %d", closer.closed.Load())
}
if got := star.takeAgentForwarder(); got != nil {
t.Fatal("expected no leaked agent forwarder after close race")
}
}
func TestProxySSHAgentChannelClosesBlockedAgentConnWhenRemoteChannelEnds(t *testing.T) {
agentConn, peerConn := net.Pipe()
defer peerConn.Close()
tracked := &trackedConn{Conn: agentConn}
channel := newTestSSHChannel(func(p []byte) (int, error) {
return 0, io.EOF
})
done := make(chan struct{})
go func() {
proxySSHAgentChannel(channel, tracked)
close(done)
}()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("proxySSHAgentChannel did not exit after remote EOF")
}
if tracked.closed.Load() == 0 {
t.Fatal("expected local agent connection to be closed")
}
if channel.closed.Load() == 0 {
t.Fatal("expected ssh channel to be closed")
}
}
func TestSSHAgentForwardProxyCloseClosesActiveBridges(t *testing.T) {
agentConn, peerConn := net.Pipe()
defer peerConn.Close()
tracked := &trackedConn{Conn: agentConn}
channel := newBlockingTestSSHChannel()
proxy := &sshAgentForwardProxy{
stopCh: make(chan struct{}),
active: make(map[*sshAgentForwardBridge]struct{}),
}
bridge := &sshAgentForwardBridge{
proxy: proxy,
channel: channel,
conn: tracked,
}
if !proxy.registerBridge(bridge) {
t.Fatal("expected bridge registration to succeed")
}
done := make(chan struct{})
go func() {
bridge.run()
close(done)
}()
if err := proxy.Close(); err != nil {
t.Fatalf("close proxy: %v", err)
}
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("bridge did not exit after proxy close")
}
if tracked.closed.Load() == 0 {
t.Fatal("expected proxy close to close local agent connection")
}
if channel.closed.Load() == 0 {
t.Fatal("expected proxy close to close ssh channel")
}
}
func TestHandleSSHAgentForwardChannelUsesForwardTimeout(t *testing.T) {
oldDialResolvedSSHAgent := dialResolvedSSHAgentFunc
t.Cleanup(func() {
dialResolvedSSHAgentFunc = oldDialResolvedSSHAgent
})
agentConn, peerConn := net.Pipe()
defer peerConn.Close()
tracked := &trackedConn{Conn: agentConn}
dialResolvedSSHAgentFunc = func(resolved resolvedSSHAgentEndpoint, timeout time.Duration) (net.Conn, error) {
return tracked, nil
}
channel := newBlockingTestSSHChannel()
newChannel := &testNewChannel{
channel: channel,
}
proxy := &sshAgentForwardProxy{
stopCh: make(chan struct{}),
active: make(map[*sshAgentForwardBridge]struct{}),
}
handleSSHAgentForwardChannel(proxy, newChannel, sshAgentTimeouts{
Endpoint: "/tmp/agent.sock",
Forward: 20 * time.Millisecond,
})
if !newChannel.accepted.Load() {
t.Fatal("expected channel to be accepted")
}
waitUntil(t, time.Second, func() bool {
return tracked.closed.Load() > 0 && channel.closed.Load() > 0
}, "forwarded agent bridge did not close both sides after idle timeout")
waitUntil(t, time.Second, func() bool {
proxy.activeMu.Lock()
defer proxy.activeMu.Unlock()
return len(proxy.active) == 0
}, "forwarded agent bridge did not unregister after idle timeout")
}
func waitUntil(t *testing.T, timeout time.Duration, condition func() bool, message string) {
t.Helper()
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
if condition() {
return
}
time.Sleep(time.Millisecond)
}
t.Fatal(message)
}
+172
View File
@@ -0,0 +1,172 @@
package starssh
import (
"context"
"errors"
"net"
"sync/atomic"
"testing"
"time"
"golang.org/x/crypto/ssh"
)
func TestPingContextDoesNotCloseConnectionOnCancel(t *testing.T) {
oldSendKeepAliveRequest := sendKeepAliveRequest
oldCloseSSHClient := closeSSHClient
t.Cleanup(func() {
sendKeepAliveRequest = oldSendKeepAliveRequest
closeSSHClient = oldCloseSSHClient
})
sendKeepAliveRequest = func(ctx context.Context, client sshClientRequester) error {
<-ctx.Done()
time.Sleep(20 * time.Millisecond)
return ctx.Err()
}
var closeCalls atomic.Int32
closeSSHClient = func(client sshClientRequester) error {
closeCalls.Add(1)
return nil
}
client := &ssh.Client{}
star := &StarSSH{}
star.setTransport(client, nil)
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
err := star.PingContext(ctx)
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("expected deadline exceeded, got %v", err)
}
if closeCalls.Load() != 0 {
t.Fatalf("expected PingContext to keep connection open, close calls=%d", closeCalls.Load())
}
if got := star.snapshotSSHClient(); got != client {
t.Fatal("expected ssh client to remain attached after PingContext cancel")
}
}
func TestPingContextCloseOnCancelClosesConnection(t *testing.T) {
oldSendKeepAliveRequest := sendKeepAliveRequest
oldCloseSSHClient := closeSSHClient
t.Cleanup(func() {
sendKeepAliveRequest = oldSendKeepAliveRequest
closeSSHClient = oldCloseSSHClient
})
sendKeepAliveRequest = func(ctx context.Context, client sshClientRequester) error {
<-ctx.Done()
time.Sleep(20 * time.Millisecond)
return ctx.Err()
}
var closeCalls atomic.Int32
closeSSHClient = func(client sshClientRequester) error {
closeCalls.Add(1)
return nil
}
star := &StarSSH{}
star.setTransport(&ssh.Client{}, nil)
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
err := star.PingContextCloseOnCancel(ctx)
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("expected deadline exceeded, got %v", err)
}
if closeCalls.Load() != 1 {
t.Fatalf("expected exactly one close call, got %d", closeCalls.Load())
}
if got := star.snapshotSSHClient(); got != nil {
t.Fatal("expected ssh client to be detached after PingContextCloseOnCancel")
}
}
func TestDialTCPContextDoesNotCloseConnectionOnCancel(t *testing.T) {
oldDialSSHClient := dialSSHClient
oldCloseSSHClient := closeSSHClient
t.Cleanup(func() {
dialSSHClient = oldDialSSHClient
closeSSHClient = oldCloseSSHClient
})
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
<-ctx.Done()
time.Sleep(20 * time.Millisecond)
return nil, ctx.Err()
}
var closeCalls atomic.Int32
closeSSHClient = func(client sshClientRequester) error {
closeCalls.Add(1)
return nil
}
client := &ssh.Client{}
star := &StarSSH{}
star.setTransport(client, nil)
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
conn, err := star.DialTCPContext(ctx, "tcp", "127.0.0.1:22")
if conn != nil {
t.Fatal("expected nil connection on canceled dial")
}
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("expected deadline exceeded, got %v", err)
}
if closeCalls.Load() != 0 {
t.Fatalf("expected DialTCPContext to keep connection open, close calls=%d", closeCalls.Load())
}
if got := star.snapshotSSHClient(); got != client {
t.Fatal("expected ssh client to remain attached after DialTCPContext cancel")
}
}
func TestDialTCPContextCloseOnCancelClosesConnection(t *testing.T) {
oldDialSSHClient := dialSSHClient
oldCloseSSHClient := closeSSHClient
t.Cleanup(func() {
dialSSHClient = oldDialSSHClient
closeSSHClient = oldCloseSSHClient
})
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
<-ctx.Done()
time.Sleep(20 * time.Millisecond)
return nil, ctx.Err()
}
var closeCalls atomic.Int32
closeSSHClient = func(client sshClientRequester) error {
closeCalls.Add(1)
return nil
}
star := &StarSSH{}
star.setTransport(&ssh.Client{}, nil)
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
conn, err := star.DialTCPContextCloseOnCancel(ctx, "tcp", "127.0.0.1:22")
if conn != nil {
t.Fatal("expected nil connection on canceled dial")
}
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("expected deadline exceeded, got %v", err)
}
if closeCalls.Load() != 1 {
t.Fatalf("expected exactly one close call, got %d", closeCalls.Load())
}
if got := star.snapshotSSHClient(); got != nil {
t.Fatal("expected ssh client to be detached after DialTCPContextCloseOnCancel")
}
}
+1218
View File
File diff suppressed because it is too large Load Diff
+69
View File
@@ -0,0 +1,69 @@
package starssh
import (
"context"
"strings"
"testing"
)
func TestExistsFallsBackFromPOSIXToPowerShell(t *testing.T) {
oldExecRequestRunner := execRequestRunner
t.Cleanup(func() {
execRequestRunner = oldExecRequestRunner
})
var calls []ExecRequest
execRequestRunner = func(s *StarSSH, ctx context.Context, req ExecRequest) (*ExecResult, error) {
calls = append(calls, req)
switch len(calls) {
case 1:
return &ExecResult{ExitCode: 127, Stderr: []byte("sh not found")}, nil
case 2:
return &ExecResult{Stdout: []byte("1\n")}, nil
default:
t.Fatalf("unexpected extra probe request: %+v", req)
return nil, nil
}
}
star := &StarSSH{}
if !star.Exists(`C:\Windows\System32`) {
t.Fatal("expected helper to succeed after PowerShell fallback")
}
if len(calls) != 2 {
t.Fatalf("expected two probe attempts, got %d", len(calls))
}
if calls[0].ShellDialect != ExecShellDialectRaw || !strings.HasPrefix(calls[0].Command, "sh -lc ") {
t.Fatalf("unexpected first probe request: %+v", calls[0])
}
if calls[1].ShellDialect != ExecShellDialectRaw || !strings.HasPrefix(strings.ToLower(calls[1].Command), "powershell.exe ") {
t.Fatalf("unexpected second probe request: %+v", calls[1])
}
}
func TestGetUserFallsBackToCMDWhenPowerShellVariantsFail(t *testing.T) {
oldExecRequestRunner := execRequestRunner
t.Cleanup(func() {
execRequestRunner = oldExecRequestRunner
})
var calls []ExecRequest
execRequestRunner = func(s *StarSSH, ctx context.Context, req ExecRequest) (*ExecResult, error) {
calls = append(calls, req)
if len(calls) < 5 {
return &ExecResult{ExitCode: 127, Stderr: []byte("command not found")}, nil
}
return &ExecResult{Stdout: []byte("HOST\\tester\r\n")}, nil
}
star := &StarSSH{}
if got := star.GetUser(); got != `HOST\tester` {
t.Fatalf("unexpected user after fallback: %q", got)
}
if len(calls) != 5 {
t.Fatalf("expected five probe attempts, got %d", len(calls))
}
if got := strings.ToLower(calls[4].Command); got != "cmd.exe /q /d /c whoami" {
t.Fatalf("unexpected final fallback command: %q", calls[4].Command)
}
}
+952
View File
@@ -0,0 +1,952 @@
package starssh
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"strconv"
"strings"
"sync"
"syscall"
"time"
"golang.org/x/crypto/ssh"
)
type ForwardRequest struct {
// Keep the exported shape compatible with older positional literals:
// ForwardRequest{listenAddr, targetAddr, dialContext}.
//
// Non-default networks can be encoded with an explicit scheme-like prefix:
// "tcp4://127.0.0.1:22", "tcp6://[::1]:22", "unix:///tmp/socket".
// Bare values default to the "tcp" network.
ListenAddr string
TargetAddr string
DialContext DialContextFunc
}
type normalizedForwardRequest struct {
ListenNetwork string
ListenAddr string
TargetNetwork string
TargetAddr string
DialContext DialContextFunc
}
type DynamicForwardRequest struct {
ListenAddr string
}
type PortForwarder struct {
listener net.Listener
ctx context.Context
cancel context.CancelFunc
acceptDone chan struct{}
connWG sync.WaitGroup
closeOnce sync.Once
connMu sync.Mutex
conns map[net.Conn]struct{}
errMu sync.Mutex
err error
cleanupOnce sync.Once
cleanupFns []func() error
}
const unixForwardProbeTimeout = 200 * time.Millisecond
var dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
return client.Dial(network, address)
}
var listenSSHClient = func(client *ssh.Client, network, address string) (net.Listener, error) {
return client.Listen(network, address)
}
var newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) {
if ctx == nil {
ctx = context.Background()
}
return LoginContext(ctx, input)
}
func (s *StarSSH) DialTCP(network string, address string) (net.Conn, error) {
return s.DialTCPContext(context.Background(), network, address)
}
func (s *StarSSH) DialTCPContext(ctx context.Context, network string, address string) (net.Conn, error) {
return s.dialTCPContext(ctx, network, address, nil)
}
func (s *StarSSH) DialTCPContextCloseOnCancel(ctx context.Context, network string, address string) (net.Conn, error) {
return s.dialTCPContext(ctx, network, address, s.Close)
}
func (s *StarSSH) StartLocalTCPForward(listenAddr string, targetAddr string) (*PortForwarder, error) {
return s.StartLocalForward(ForwardRequest{
ListenAddr: listenAddr,
TargetAddr: targetAddr,
})
}
func (s *StarSSH) StartLocalTCPForwardDetached(listenAddr string, targetAddr string) (*PortForwarder, error) {
return s.StartLocalForwardDetached(ForwardRequest{
ListenAddr: listenAddr,
TargetAddr: targetAddr,
})
}
func (s *StarSSH) StartLocalTCPToUnixForward(listenAddr string, targetPath string) (*PortForwarder, error) {
return s.StartLocalForward(ForwardRequest{
ListenAddr: listenAddr,
TargetAddr: forwardEndpoint("unix", targetPath),
})
}
func (s *StarSSH) StartLocalTCPToUnixForwardDetached(listenAddr string, targetPath string) (*PortForwarder, error) {
return s.StartLocalForwardDetached(ForwardRequest{
ListenAddr: listenAddr,
TargetAddr: forwardEndpoint("unix", targetPath),
})
}
func (s *StarSSH) StartLocalUnixForward(listenPath string, targetAddr string) (*PortForwarder, error) {
return s.StartLocalForward(ForwardRequest{
ListenAddr: forwardEndpoint("unix", listenPath),
TargetAddr: targetAddr,
})
}
func (s *StarSSH) StartLocalUnixForwardDetached(listenPath string, targetAddr string) (*PortForwarder, error) {
return s.StartLocalForwardDetached(ForwardRequest{
ListenAddr: forwardEndpoint("unix", listenPath),
TargetAddr: targetAddr,
})
}
func (s *StarSSH) StartLocalUnixToUnixForward(listenPath string, targetPath string) (*PortForwarder, error) {
return s.StartLocalForward(ForwardRequest{
ListenAddr: forwardEndpoint("unix", listenPath),
TargetAddr: forwardEndpoint("unix", targetPath),
})
}
func (s *StarSSH) StartLocalUnixToUnixForwardDetached(listenPath string, targetPath string) (*PortForwarder, error) {
return s.StartLocalForwardDetached(ForwardRequest{
ListenAddr: forwardEndpoint("unix", listenPath),
TargetAddr: forwardEndpoint("unix", targetPath),
})
}
func (s *StarSSH) StartRemoteTCPForward(listenAddr string, targetAddr string) (*PortForwarder, error) {
return s.StartRemoteForward(ForwardRequest{
ListenAddr: listenAddr,
TargetAddr: targetAddr,
})
}
func (s *StarSSH) StartRemoteTCPToUnixForward(listenAddr string, targetPath string) (*PortForwarder, error) {
return s.StartRemoteForward(ForwardRequest{
ListenAddr: listenAddr,
TargetAddr: forwardEndpoint("unix", targetPath),
})
}
func (s *StarSSH) StartRemoteUnixForward(listenPath string, targetAddr string) (*PortForwarder, error) {
return s.StartRemoteForward(ForwardRequest{
ListenAddr: forwardEndpoint("unix", listenPath),
TargetAddr: targetAddr,
})
}
func (s *StarSSH) StartRemoteUnixToUnixForward(listenPath string, targetPath string) (*PortForwarder, error) {
return s.StartRemoteForward(ForwardRequest{
ListenAddr: forwardEndpoint("unix", listenPath),
TargetAddr: forwardEndpoint("unix", targetPath),
})
}
func (s *StarSSH) dialTCPContext(ctx context.Context, network string, address string, onCancel func() error) (net.Conn, error) {
if ctx == nil {
ctx = context.Background()
}
if strings.TrimSpace(network) == "" {
network = "tcp"
}
if strings.TrimSpace(address) == "" {
return nil, errors.New("forward address is empty")
}
type dialResult struct {
conn net.Conn
err error
}
client, err := s.requireSSHClient()
if err != nil {
return nil, err
}
runCancel := func() {}
if onCancel != nil {
var cancelOnce sync.Once
runCancel = func() {
cancelOnce.Do(func() {
_ = onCancel()
})
}
cancelDone := make(chan struct{})
defer close(cancelDone)
go func() {
select {
case <-ctx.Done():
runCancel()
case <-cancelDone:
}
}()
}
dialFunc := dialSSHClient
resultCh := make(chan dialResult, 1)
go func() {
conn, err := dialFunc(ctx, client, network, address)
if ctx.Err() != nil && conn != nil {
_ = conn.Close()
conn = nil
}
select {
case resultCh <- dialResult{conn: conn, err: err}:
default:
if conn != nil {
_ = conn.Close()
}
}
}()
select {
case result := <-resultCh:
return result.conn, result.err
case <-ctx.Done():
runCancel()
return nil, ctx.Err()
}
}
func (s *StarSSH) StartLocalForward(req ForwardRequest) (*PortForwarder, error) {
if _, err := s.requireSSHClient(); err != nil {
return nil, err
}
normalizedReq, err := normalizeForwardRequest(req)
if err != nil {
return nil, err
}
if strings.TrimSpace(normalizedReq.ListenAddr) == "" {
return nil, errors.New("local forward listen address is empty")
}
listener, cleanup, err := prepareLocalForwardListener(normalizedReq.ListenNetwork, normalizedReq.ListenAddr)
if err != nil {
return nil, err
}
forwarder := newPortForwarder(listener)
forwarder.addCleanup(cleanup)
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
return s.DialTCPContext(ctx, normalizedReq.TargetNetwork, normalizedReq.TargetAddr)
})
return forwarder, nil
}
func (s *StarSSH) StartLocalForwardDetached(req ForwardRequest) (*PortForwarder, error) {
if _, err := s.requireSSHClient(); err != nil {
return nil, err
}
normalizedReq, err := normalizeForwardRequest(req)
if err != nil {
return nil, err
}
listener, cleanup, err := prepareLocalForwardListener(normalizedReq.ListenNetwork, normalizedReq.ListenAddr)
if err != nil {
return nil, err
}
forwardClient, err := s.newForwardDialClient(context.Background())
if err != nil {
_ = listener.Close()
if cleanup != nil {
_ = cleanup()
}
return nil, err
}
forwarder := newPortForwarder(listener)
forwarder.addCleanup(cleanup)
forwarder.addCleanup(func() error {
return normalizeAlreadyClosedError(forwardClient.Close())
})
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
return forwardClient.DialTCPContext(ctx, normalizedReq.TargetNetwork, normalizedReq.TargetAddr)
})
return forwarder, nil
}
func (s *StarSSH) StartRemoteForward(req ForwardRequest) (*PortForwarder, error) {
client, err := s.requireSSHClient()
if err != nil {
return nil, err
}
normalizedReq, err := normalizeForwardRequest(req)
if err != nil {
return nil, err
}
listener, err := listenSSHClient(client, normalizedReq.ListenNetwork, normalizedReq.ListenAddr)
if err != nil {
return nil, err
}
dialContext := normalizedReq.DialContext
if dialContext == nil {
dialer := &net.Dialer{
Timeout: defaultLoginTimeout,
}
dialContext = dialer.DialContext
}
forwarder := newPortForwarder(listener)
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
return dialContext(ctx, normalizedReq.TargetNetwork, normalizedReq.TargetAddr)
})
return forwarder, nil
}
func (s *StarSSH) StartDynamicForward(req DynamicForwardRequest) (*PortForwarder, error) {
if _, err := s.requireSSHClient(); err != nil {
return nil, err
}
if strings.TrimSpace(req.ListenAddr) == "" {
return nil, errors.New("dynamic forward listen address is empty")
}
listener, err := net.Listen("tcp", req.ListenAddr)
if err != nil {
return nil, err
}
forwarder := newPortForwarder(listener)
forwarder.serveDynamic(func(ctx context.Context, targetAddr string) (net.Conn, error) {
return s.DialTCPContext(ctx, "tcp", targetAddr)
})
return forwarder, nil
}
func normalizeForwardRequest(req ForwardRequest) (normalizedForwardRequest, error) {
normalized := normalizedForwardRequest{
DialContext: req.DialContext,
}
var err error
normalized.ListenNetwork, normalized.ListenAddr, err = parseForwardEndpoint(req.ListenAddr)
if err != nil {
return normalized, fmt.Errorf("normalize listen address: %w", err)
}
normalized.TargetNetwork, normalized.TargetAddr, err = parseForwardEndpoint(req.TargetAddr)
if err != nil {
return normalized, fmt.Errorf("normalize target address: %w", err)
}
if strings.TrimSpace(normalized.ListenAddr) == "" {
return normalized, errors.New("forward listen address is empty")
}
if strings.TrimSpace(normalized.TargetAddr) == "" {
return normalized, errors.New("forward target address is empty")
}
return normalized, nil
}
func normalizeForwardNetwork(network string) string {
network = strings.ToLower(strings.TrimSpace(network))
if network == "" {
return "tcp"
}
return network
}
func isSupportedForwardNetwork(network string) bool {
switch network {
case "tcp", "tcp4", "tcp6", "unix":
return true
default:
return false
}
}
func parseForwardEndpoint(value string) (network string, address string, err error) {
value = strings.TrimSpace(value)
if value == "" {
return "tcp", "", nil
}
lowerValue := strings.ToLower(value)
for _, prefix := range []string{"tcp4://", "tcp6://", "tcp://", "unix://"} {
if strings.HasPrefix(lowerValue, prefix) {
network = normalizeForwardNetwork(strings.TrimSuffix(prefix, "://"))
address = value[len(prefix):]
if !isSupportedForwardNetwork(network) {
return "", "", fmt.Errorf("unsupported forward network %q", network)
}
return network, address, nil
}
}
return "tcp", value, nil
}
func forwardEndpoint(network string, address string) string {
network = normalizeForwardNetwork(network)
if network == "tcp" {
return address
}
return network + "://" + address
}
func (s *StarSSH) StartDynamicForwardDetached(req DynamicForwardRequest) (*PortForwarder, error) {
if _, err := s.requireSSHClient(); err != nil {
return nil, err
}
if strings.TrimSpace(req.ListenAddr) == "" {
return nil, errors.New("dynamic forward listen address is empty")
}
listener, err := net.Listen("tcp", req.ListenAddr)
if err != nil {
return nil, err
}
forwardClient, err := s.newForwardDialClient(context.Background())
if err != nil {
_ = listener.Close()
return nil, err
}
forwarder := newPortForwarder(listener)
forwarder.addCleanup(func() error {
return normalizeAlreadyClosedError(forwardClient.Close())
})
forwarder.serveDynamic(func(ctx context.Context, targetAddr string) (net.Conn, error) {
return forwardClient.DialTCPContext(ctx, "tcp", targetAddr)
})
return forwarder, nil
}
func (f *PortForwarder) Addr() net.Addr {
if f == nil || f.listener == nil {
return nil
}
return f.listener.Addr()
}
func (f *PortForwarder) Wait() error {
if f == nil {
return nil
}
<-f.acceptDone
f.connWG.Wait()
f.runCleanup()
return f.Err()
}
func (f *PortForwarder) Err() error {
if f == nil {
return nil
}
f.errMu.Lock()
defer f.errMu.Unlock()
return f.err
}
func (f *PortForwarder) Close() error {
if f == nil {
return nil
}
var closeErr error
f.closeOnce.Do(func() {
if f.cancel != nil {
f.cancel()
}
if f.listener != nil {
closeErr = normalizeAlreadyClosedError(f.listener.Close())
}
f.closeActiveConnections()
})
<-f.acceptDone
f.connWG.Wait()
f.runCleanup()
if closeErr != nil {
return closeErr
}
return f.Err()
}
func newPortForwarder(listener net.Listener) *PortForwarder {
ctx, cancel := context.WithCancel(context.Background())
return &PortForwarder{
listener: listener,
ctx: ctx,
cancel: cancel,
acceptDone: make(chan struct{}),
conns: make(map[net.Conn]struct{}),
}
}
func (s *StarSSH) newForwardDialClient(ctx context.Context) (*StarSSH, error) {
if s == nil {
return nil, errors.New("ssh client is nil")
}
return newDetachedForwardClient(ctx, s.LoginInfo)
}
func (f *PortForwarder) addCleanup(fn func() error) {
if f == nil || fn == nil {
return
}
f.cleanupFns = append(f.cleanupFns, fn)
}
func prepareLocalForwardListener(network string, address string) (net.Listener, func() error, error) {
network = normalizeForwardNetwork(network)
if network != "unix" {
listener, err := net.Listen(network, address)
return listener, nil, err
}
if err := removeStaleUnixSocket(address); err != nil {
return nil, nil, err
}
listener, err := net.Listen(network, address)
if err != nil {
return nil, nil, err
}
cleanup, err := makeUnixSocketCleanup(address)
if err != nil {
_ = listener.Close()
_ = removeUnixSocketPath(address)
return nil, nil, err
}
return listener, cleanup, nil
}
func removeStaleUnixSocket(path string) error {
info, err := os.Lstat(path)
if errors.Is(err, os.ErrNotExist) {
return nil
}
if err != nil {
return err
}
if info.Mode()&os.ModeSocket == 0 {
return fmt.Errorf("local unix forward path %q already exists and is not a socket", path)
}
conn, err := net.DialTimeout("unix", path, unixForwardProbeTimeout)
if err == nil {
_ = conn.Close()
return fmt.Errorf("local unix forward path %q is already in use", path)
}
if !isStaleUnixSocketDialError(err) {
return fmt.Errorf("probe existing unix socket %q: %w", path, err)
}
return removeUnixSocketPath(path)
}
func isStaleUnixSocketDialError(err error) bool {
return errors.Is(err, syscall.ECONNREFUSED) || errors.Is(err, syscall.ENOENT)
}
func makeUnixSocketCleanup(path string) (func() error, error) {
info, err := os.Lstat(path)
if err != nil {
return nil, err
}
return func() error {
current, err := os.Lstat(path)
if errors.Is(err, os.ErrNotExist) {
return nil
}
if err != nil {
return err
}
if current.Mode()&os.ModeSocket == 0 || !os.SameFile(info, current) {
return nil
}
return removeUnixSocketPath(path)
}, nil
}
func removeUnixSocketPath(path string) error {
err := os.Remove(path)
if errors.Is(err, os.ErrNotExist) {
return nil
}
return err
}
func (f *PortForwarder) runCleanup() {
if f == nil {
return
}
f.cleanupOnce.Do(func() {
for _, fn := range f.cleanupFns {
f.setError(normalizeAlreadyClosedError(fn()))
}
})
}
func (f *PortForwarder) serve(targetDial func(context.Context) (net.Conn, error)) {
go func() {
defer close(f.acceptDone)
for {
conn, err := f.listener.Accept()
if err != nil {
if isClosedListenerError(err) {
return
}
f.setError(err)
return
}
f.trackConn(conn)
f.connWG.Add(1)
go func(src net.Conn) {
defer f.connWG.Done()
defer f.untrackConn(src)
dst, err := targetDial(f.ctx)
if err != nil {
f.setError(err)
_ = src.Close()
return
}
f.trackConn(dst)
defer f.untrackConn(dst)
f.setError(pipeForwardConnections(src, dst))
}(conn)
}
}()
}
func (f *PortForwarder) serveDynamic(targetDial func(context.Context, string) (net.Conn, error)) {
go func() {
defer close(f.acceptDone)
for {
conn, err := f.listener.Accept()
if err != nil {
if isClosedListenerError(err) {
return
}
f.setError(err)
return
}
f.trackConn(conn)
f.connWG.Add(1)
go func(src net.Conn) {
defer f.connWG.Done()
defer f.untrackConn(src)
if err := handleDynamicForwardConn(f.ctx, src, targetDial, f.trackConn, f.untrackConn); err != nil {
f.setError(err)
}
}(conn)
}
}()
}
func (f *PortForwarder) setError(err error) {
if f == nil || err == nil || f.shouldIgnoreError(err) {
return
}
f.errMu.Lock()
defer f.errMu.Unlock()
if f.err == nil {
f.err = err
}
}
func pipeForwardConnections(left net.Conn, right net.Conn) error {
if left == nil || right == nil {
if left != nil {
_ = left.Close()
}
if right != nil {
_ = right.Close()
}
return errors.New("forward connection endpoint is nil")
}
defer left.Close()
defer right.Close()
var copyWG sync.WaitGroup
errCh := make(chan error, 2)
copyWG.Add(2)
go func() {
defer copyWG.Done()
_, err := io.Copy(right, left)
errCh <- normalizeAlreadyClosedError(err)
closeWrite(right)
}()
go func() {
defer copyWG.Done()
_, err := io.Copy(left, right)
errCh <- normalizeAlreadyClosedError(err)
closeWrite(left)
}()
copyWG.Wait()
close(errCh)
for err := range errCh {
if err != nil {
return err
}
}
return nil
}
func closeWrite(conn net.Conn) {
if conn == nil {
return
}
type closeWriter interface {
CloseWrite() error
}
if writer, ok := conn.(closeWriter); ok {
_ = writer.CloseWrite()
}
}
func handleDynamicForwardConn(
ctx context.Context,
src net.Conn,
targetDial func(context.Context, string) (net.Conn, error),
trackConn func(net.Conn),
untrackConn func(net.Conn),
) error {
if src == nil {
return errors.New("dynamic forward source connection is nil")
}
defer src.Close()
if err := negotiateSOCKS5NoAuth(src); err != nil {
return err
}
targetAddr, replyCode, err := readSOCKS5ConnectTarget(src)
if err != nil {
_ = writeSOCKS5ServerReply(src, replyCode, nil)
return err
}
dst, err := targetDial(ctx, targetAddr)
if err != nil {
_ = writeSOCKS5ServerReply(src, 0x01, nil)
return err
}
if trackConn != nil {
trackConn(dst)
defer untrackConn(dst)
}
if err := writeSOCKS5ServerReply(src, 0x00, dst.LocalAddr()); err != nil {
_ = dst.Close()
return err
}
return pipeForwardConnections(src, dst)
}
func (f *PortForwarder) trackConn(conn net.Conn) {
if f == nil || conn == nil {
return
}
f.connMu.Lock()
defer f.connMu.Unlock()
f.conns[conn] = struct{}{}
}
func (f *PortForwarder) untrackConn(conn net.Conn) {
if f == nil || conn == nil {
return
}
f.connMu.Lock()
defer f.connMu.Unlock()
delete(f.conns, conn)
}
func (f *PortForwarder) closeActiveConnections() {
if f == nil {
return
}
f.connMu.Lock()
conns := make([]net.Conn, 0, len(f.conns))
for conn := range f.conns {
conns = append(conns, conn)
}
f.connMu.Unlock()
for _, conn := range conns {
_ = conn.Close()
}
}
func (f *PortForwarder) shouldIgnoreError(err error) bool {
if err == nil {
return true
}
if normalizeAlreadyClosedError(err) == nil {
return true
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return true
}
return false
}
func negotiateSOCKS5NoAuth(conn net.Conn) error {
header := make([]byte, 2)
if _, err := io.ReadFull(conn, header); err != nil {
return err
}
if header[0] != 0x05 {
return errors.New("invalid socks5 version")
}
methodCount := int(header[1])
methods := make([]byte, methodCount)
if _, err := io.ReadFull(conn, methods); err != nil {
return err
}
method := byte(0xFF)
for _, candidate := range methods {
if candidate == 0x00 {
method = 0x00
break
}
}
if _, err := conn.Write([]byte{0x05, method}); err != nil {
return err
}
if method == 0xFF {
return errors.New("socks5 client does not support no-auth method")
}
return nil
}
func readSOCKS5ConnectTarget(conn net.Conn) (string, byte, error) {
header := make([]byte, 4)
if _, err := io.ReadFull(conn, header); err != nil {
return "", 0x01, err
}
if header[0] != 0x05 {
return "", 0x01, errors.New("invalid socks5 request version")
}
if header[1] != 0x01 {
return "", 0x07, errors.New("unsupported socks5 command")
}
host, err := readSOCKS5RequestHost(conn, header[3])
if err != nil {
return "", 0x08, err
}
portBytes := make([]byte, 2)
if _, err := io.ReadFull(conn, portBytes); err != nil {
return "", 0x01, err
}
port := int(portBytes[0])<<8 | int(portBytes[1])
return net.JoinHostPort(host, strconv.Itoa(port)), 0x00, nil
}
func readSOCKS5RequestHost(conn net.Conn, addressType byte) (string, error) {
switch addressType {
case 0x01:
buffer := make([]byte, 4)
if _, err := io.ReadFull(conn, buffer); err != nil {
return "", err
}
return net.IP(buffer).String(), nil
case 0x03:
size := make([]byte, 1)
if _, err := io.ReadFull(conn, size); err != nil {
return "", err
}
buffer := make([]byte, int(size[0]))
if _, err := io.ReadFull(conn, buffer); err != nil {
return "", err
}
return string(buffer), nil
case 0x04:
buffer := make([]byte, 16)
if _, err := io.ReadFull(conn, buffer); err != nil {
return "", err
}
return net.IP(buffer).String(), nil
default:
return "", errors.New("unsupported socks5 address type")
}
}
func writeSOCKS5ServerReply(conn net.Conn, replyCode byte, addr net.Addr) error {
reply := []byte{0x05, replyCode, 0x00}
if tcpAddr, ok := addr.(*net.TCPAddr); ok && tcpAddr != nil {
if ip4 := tcpAddr.IP.To4(); ip4 != nil {
reply = append(reply, 0x01)
reply = append(reply, ip4...)
} else if ip16 := tcpAddr.IP.To16(); ip16 != nil {
reply = append(reply, 0x04)
reply = append(reply, ip16...)
}
if len(reply) > 3 {
port := tcpAddr.Port
reply = append(reply, byte(port>>8), byte(port))
_, err := conn.Write(reply)
return err
}
}
reply = append(reply, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00)
_, err := conn.Write(reply)
return err
}
func isClosedListenerError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, io.EOF) {
return true
}
if errors.Is(err, net.ErrClosed) {
return true
}
return strings.Contains(err.Error(), "use of closed network connection")
}
+698
View File
@@ -0,0 +1,698 @@
package starssh
import (
"context"
"errors"
"io"
"net"
"os"
"path/filepath"
"runtime"
"sync"
"sync/atomic"
"testing"
"time"
"golang.org/x/crypto/ssh"
)
type stubListener struct {
addr net.Addr
acceptCh chan net.Conn
closeCh chan struct{}
closeOnce sync.Once
}
type dialRecord struct {
network string
addr string
}
func newStubListener(addr net.Addr) *stubListener {
return &stubListener{
addr: addr,
acceptCh: make(chan net.Conn, 1),
closeCh: make(chan struct{}),
}
}
func (l *stubListener) Accept() (net.Conn, error) {
select {
case conn, ok := <-l.acceptCh:
if !ok {
return nil, io.EOF
}
return conn, nil
case <-l.closeCh:
return nil, net.ErrClosed
}
}
func (l *stubListener) Close() error {
l.closeOnce.Do(func() {
close(l.closeCh)
close(l.acceptCh)
})
return nil
}
func (l *stubListener) Addr() net.Addr {
return l.addr
}
func (l *stubListener) Push(conn net.Conn) error {
select {
case <-l.closeCh:
return net.ErrClosed
case l.acceptCh <- conn:
return nil
}
}
func TestStartLocalForwardUsesExistingConnectionByDefault(t *testing.T) {
oldDialSSHClient := dialSSHClient
oldNewDetachedForwardClient := newDetachedForwardClient
oldCloseSSHClient := closeSSHClient
t.Cleanup(func() {
dialSSHClient = oldDialSSHClient
newDetachedForwardClient = oldNewDetachedForwardClient
closeSSHClient = oldCloseSSHClient
})
baseClient := &ssh.Client{}
star := &StarSSH{}
star.setTransport(baseClient, nil)
var detachedCalls atomic.Int32
newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) {
detachedCalls.Add(1)
return nil, nil
}
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
if client != baseClient {
t.Errorf("expected existing ssh client, got %p want %p", client, baseClient)
}
serverConn, clientConn := net.Pipe()
go echoForwardPipe(serverConn)
return clientConn, nil
}
closeSSHClient = func(client sshClientRequester) error {
t.Fatal("default local forward should not close the main ssh client")
return nil
}
forwarder, err := star.StartLocalForward(ForwardRequest{
ListenAddr: "127.0.0.1:0",
TargetAddr: "example.internal:22",
})
if err != nil {
t.Fatalf("start local forward: %v", err)
}
defer forwarder.Close()
reply := exerciseForwarder(t, forwarder.Addr().String(), []byte("ping"))
if string(reply) != "ping" {
t.Fatalf("unexpected forwarded reply: %q", string(reply))
}
if detachedCalls.Load() != 0 {
t.Fatalf("default local forward should not create detached ssh client, calls=%d", detachedCalls.Load())
}
}
func TestForwardRequestLegacyPositionalLiteralDefaultsToTCP(t *testing.T) {
dialer := func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, nil
}
req, err := normalizeForwardRequest(ForwardRequest{
"127.0.0.1:10022",
"example.internal:22",
dialer,
})
if err != nil {
t.Fatalf("normalizeForwardRequest: %v", err)
}
if req.ListenNetwork != "tcp" {
t.Fatalf("ListenNetwork=%q want tcp", req.ListenNetwork)
}
if req.TargetNetwork != "tcp" {
t.Fatalf("TargetNetwork=%q want tcp", req.TargetNetwork)
}
if req.ListenAddr != "127.0.0.1:10022" || req.TargetAddr != "example.internal:22" {
t.Fatalf("unexpected normalized request: %+v", req)
}
if req.DialContext == nil {
t.Fatal("expected DialContext to be preserved")
}
}
func TestParseForwardEndpointTreatsTCPPrefixLikePlainAddress(t *testing.T) {
network, address, err := parseForwardEndpoint("tcp:22")
if err != nil {
t.Fatalf("parseForwardEndpoint: %v", err)
}
if network != "tcp" {
t.Fatalf("network=%q want tcp", network)
}
if address != "tcp:22" {
t.Fatalf("address=%q want tcp:22", address)
}
}
func TestParseForwardEndpointSupportsExplicitSchemes(t *testing.T) {
network, address, err := parseForwardEndpoint("unix:///tmp/test-forward.sock")
if err != nil {
t.Fatalf("parseForwardEndpoint unix: %v", err)
}
if network != "unix" || address != "/tmp/test-forward.sock" {
t.Fatalf("unexpected unix endpoint parse: network=%q address=%q", network, address)
}
network, address, err = parseForwardEndpoint("tcp6://[::1]:2222")
if err != nil {
t.Fatalf("parseForwardEndpoint tcp6: %v", err)
}
if network != "tcp6" || address != "[::1]:2222" {
t.Fatalf("unexpected tcp6 endpoint parse: network=%q address=%q", network, address)
}
}
func TestStartLocalForwardDetachedUsesSeparateConnection(t *testing.T) {
oldDialSSHClient := dialSSHClient
oldNewDetachedForwardClient := newDetachedForwardClient
oldCloseSSHClient := closeSSHClient
t.Cleanup(func() {
dialSSHClient = oldDialSSHClient
newDetachedForwardClient = oldNewDetachedForwardClient
closeSSHClient = oldCloseSSHClient
})
baseClient := &ssh.Client{}
detachedClient := &ssh.Client{}
star := &StarSSH{LoginInfo: LoginInput{User: "tester", Addr: "127.0.0.1"}}
star.setTransport(baseClient, nil)
forwardClient := &StarSSH{}
forwardClient.setTransport(detachedClient, nil)
var detachedCalls atomic.Int32
newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) {
detachedCalls.Add(1)
return forwardClient, nil
}
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
if client != detachedClient {
t.Errorf("expected detached ssh client, got %p want %p", client, detachedClient)
}
serverConn, clientConn := net.Pipe()
go echoForwardPipe(serverConn)
return clientConn, nil
}
var closeCalls atomic.Int32
closeSSHClient = func(client sshClientRequester) error {
closeCalls.Add(1)
return nil
}
forwarder, err := star.StartLocalForwardDetached(ForwardRequest{
ListenAddr: "127.0.0.1:0",
TargetAddr: "example.internal:22",
})
if err != nil {
t.Fatalf("start detached local forward: %v", err)
}
reply := exerciseForwarder(t, forwarder.Addr().String(), []byte("pong"))
if string(reply) != "pong" {
t.Fatalf("unexpected detached forwarded reply: %q", string(reply))
}
if err := forwarder.Close(); err != nil {
t.Fatalf("close detached local forward: %v", err)
}
if detachedCalls.Load() != 1 {
t.Fatalf("expected one detached ssh login, got %d", detachedCalls.Load())
}
if closeCalls.Load() != 1 {
t.Fatalf("expected detached ssh client cleanup once, got %d", closeCalls.Load())
}
if got := star.snapshotSSHClient(); got != baseClient {
t.Fatal("detached local forward should not detach the main ssh client")
}
if got := forwardClient.snapshotSSHClient(); got != nil {
t.Fatal("detached local forward should close its detached ssh client")
}
}
func TestStartRemoteForwardSupportsUnixListenAndTCPTarget(t *testing.T) {
oldListenSSHClient := listenSSHClient
t.Cleanup(func() {
listenSSHClient = oldListenSSHClient
})
baseClient := &ssh.Client{}
star := &StarSSH{}
star.setTransport(baseClient, nil)
listener := newStubListener(&net.UnixAddr{
Name: "/run/user/0/gnupg/S.gpg-agent",
Net: "unix",
})
var listenedNetwork string
var listenedAddr string
listenSSHClient = func(client *ssh.Client, network, address string) (net.Listener, error) {
if client != baseClient {
t.Fatalf("unexpected ssh client %p", client)
}
listenedNetwork = network
listenedAddr = address
return listener, nil
}
var targetNetwork string
var targetAddr string
forwarder, err := star.StartRemoteForward(ForwardRequest{
ListenAddr: forwardEndpoint("unix", "/run/user/0/gnupg/S.gpg-agent"),
TargetAddr: "127.0.0.1:4321",
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
targetNetwork = network
targetAddr = address
serverConn, clientConn := net.Pipe()
go echoForwardPipe(serverConn)
return clientConn, nil
},
})
if err != nil {
t.Fatalf("start remote unix forward: %v", err)
}
defer forwarder.Close()
srcPeer, forwardedConn := net.Pipe()
defer srcPeer.Close()
if err := listener.Push(forwardedConn); err != nil {
t.Fatalf("push forwarded connection: %v", err)
}
payload := []byte("unix-forward")
done := make(chan []byte, 1)
go func() {
reply := make([]byte, len(payload))
_, _ = io.ReadFull(srcPeer, reply)
done <- reply
}()
if _, err := srcPeer.Write(payload); err != nil {
t.Fatalf("write source payload: %v", err)
}
select {
case reply := <-done:
if string(reply) != string(payload) {
t.Fatalf("unexpected remote unix forward reply: %q", string(reply))
}
case <-time.After(2 * time.Second):
t.Fatal("remote unix forward did not relay payload")
}
if listenedNetwork != "unix" || listenedAddr != "/run/user/0/gnupg/S.gpg-agent" {
t.Fatalf("unexpected remote listen request: network=%q addr=%q", listenedNetwork, listenedAddr)
}
if targetNetwork != "tcp" || targetAddr != "127.0.0.1:4321" {
t.Fatalf("unexpected local dial target: network=%q addr=%q", targetNetwork, targetAddr)
}
}
func TestStartLocalUnixForwardUsesUnixListenerAndTCPTarget(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("unix socket smoke test is exercised in WSL/Linux CI path")
}
oldDialSSHClient := dialSSHClient
t.Cleanup(func() {
dialSSHClient = oldDialSSHClient
})
baseClient := &ssh.Client{}
star := &StarSSH{}
star.setTransport(baseClient, nil)
var targetNetwork string
var targetAddr string
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
if client != baseClient {
t.Fatalf("unexpected ssh client %p", client)
}
targetNetwork = network
targetAddr = address
serverConn, clientConn := net.Pipe()
go echoForwardPipe(serverConn)
return clientConn, nil
}
socketPath := filepath.Join(t.TempDir(), "forward.sock")
forwarder, err := star.StartLocalUnixForward(socketPath, "127.0.0.1:4321")
if err != nil {
t.Fatalf("start local unix forward: %v", err)
}
defer func() {
closeErr := forwarder.Close()
if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) {
t.Fatalf("close local unix forward: %v", closeErr)
}
}()
conn, err := net.DialTimeout("unix", socketPath, time.Second)
if err != nil {
t.Fatalf("dial unix forward listener: %v", err)
}
defer conn.Close()
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
payload := []byte("unix-local-forward")
if _, err := conn.Write(payload); err != nil {
t.Fatalf("write unix forward payload: %v", err)
}
reply := make([]byte, len(payload))
if _, err := io.ReadFull(conn, reply); err != nil {
t.Fatalf("read unix forward reply: %v", err)
}
if string(reply) != string(payload) {
t.Fatalf("unexpected unix forward reply: %q", string(reply))
}
if targetNetwork != "tcp" || targetAddr != "127.0.0.1:4321" {
t.Fatalf("unexpected remote dial target: network=%q addr=%q", targetNetwork, targetAddr)
}
}
func TestStartLocalUnixForwardRemovesSocketOnClose(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("unix socket smoke test is exercised in WSL/Linux CI path")
}
oldDialSSHClient := dialSSHClient
t.Cleanup(func() {
dialSSHClient = oldDialSSHClient
})
baseClient := &ssh.Client{}
star := &StarSSH{}
star.setTransport(baseClient, nil)
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
serverConn, clientConn := net.Pipe()
go echoForwardPipe(serverConn)
return clientConn, nil
}
socketPath := filepath.Join(t.TempDir(), "cleanup.sock")
forwarder, err := star.StartLocalUnixForward(socketPath, "127.0.0.1:4321")
if err != nil {
t.Fatalf("start local unix forward: %v", err)
}
if _, err := os.Lstat(socketPath); err != nil {
t.Fatalf("socket should exist while forward is running: %v", err)
}
if err := forwarder.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
t.Fatalf("close local unix forward: %v", err)
}
if _, err := os.Lstat(socketPath); !errors.Is(err, os.ErrNotExist) {
t.Fatalf("socket path should be removed on close, got err=%v", err)
}
}
func TestStartLocalUnixForwardReusesStaleSocketPath(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("unix socket smoke test is exercised in WSL/Linux CI path")
}
oldDialSSHClient := dialSSHClient
t.Cleanup(func() {
dialSSHClient = oldDialSSHClient
})
baseClient := &ssh.Client{}
star := &StarSSH{}
star.setTransport(baseClient, nil)
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
serverConn, clientConn := net.Pipe()
go echoForwardPipe(serverConn)
return clientConn, nil
}
socketPath := filepath.Join(t.TempDir(), "stale.sock")
staleListener, err := net.ListenUnix("unix", &net.UnixAddr{
Name: socketPath,
Net: "unix",
})
if err != nil {
t.Fatalf("create stale unix socket: %v", err)
}
staleListener.SetUnlinkOnClose(false)
if err := staleListener.Close(); err != nil {
t.Fatalf("close stale unix socket listener: %v", err)
}
if _, err := os.Lstat(socketPath); err != nil {
t.Fatalf("expected stale unix socket path to remain after close: %v", err)
}
forwarder, err := star.StartLocalUnixForward(socketPath, "127.0.0.1:4321")
if err != nil {
t.Fatalf("start local unix forward on stale socket path: %v", err)
}
defer func() {
closeErr := forwarder.Close()
if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) {
t.Fatalf("close local unix forward: %v", closeErr)
}
}()
reply := make([]byte, len("stale-reuse"))
conn, err := net.DialTimeout("unix", socketPath, time.Second)
if err != nil {
t.Fatalf("dial reused unix forward listener: %v", err)
}
defer conn.Close()
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
if _, err := conn.Write([]byte("stale-reuse")); err != nil {
t.Fatalf("write reused unix forward payload: %v", err)
}
if _, err := io.ReadFull(conn, reply); err != nil {
t.Fatalf("read reused unix forward reply: %v", err)
}
if string(reply) != "stale-reuse" {
t.Fatalf("unexpected reply on reused unix forward: %q", string(reply))
}
}
func TestStartLocalUnixToUnixForwardUsesUnixTarget(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("unix socket smoke test is exercised in WSL/Linux CI path")
}
oldDialSSHClient := dialSSHClient
t.Cleanup(func() {
dialSSHClient = oldDialSSHClient
})
baseClient := &ssh.Client{}
star := &StarSSH{}
star.setTransport(baseClient, nil)
targetSocketPath := filepath.Join(t.TempDir(), "target.sock")
targetListener, err := net.Listen("unix", targetSocketPath)
if err != nil {
t.Fatalf("listen target unix socket: %v", err)
}
defer targetListener.Close()
done := make(chan []byte, 1)
go func() {
conn, acceptErr := targetListener.Accept()
if acceptErr != nil {
done <- nil
return
}
defer conn.Close()
buf := make([]byte, 64)
n, _ := conn.Read(buf)
_, _ = conn.Write(buf[:n])
done <- buf[:n]
}()
dialRecordCh := make(chan dialRecord, 1)
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
if client != baseClient {
t.Fatalf("unexpected ssh client %p", client)
}
dialRecordCh <- dialRecord{network: network, addr: address}
var dialer net.Dialer
return dialer.DialContext(ctx, network, address)
}
listenSocketPath := filepath.Join(t.TempDir(), "listen.sock")
forwarder, err := star.StartLocalUnixToUnixForward(listenSocketPath, targetSocketPath)
if err != nil {
t.Fatalf("start local unix-to-unix forward: %v", err)
}
defer func() {
closeErr := forwarder.Close()
if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) {
t.Fatalf("close local unix-to-unix forward: %v", closeErr)
}
}()
conn, err := net.DialTimeout("unix", listenSocketPath, time.Second)
if err != nil {
t.Fatalf("dial unix-to-unix listener: %v", err)
}
defer conn.Close()
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
payload := []byte("unix-to-unix")
if _, err := conn.Write(payload); err != nil {
t.Fatalf("write unix-to-unix payload: %v", err)
}
reply := make([]byte, len(payload))
if _, err := io.ReadFull(conn, reply); err != nil {
t.Fatalf("read unix-to-unix reply: %v", err)
}
if string(reply) != string(payload) {
t.Fatalf("unexpected unix-to-unix reply: %q", string(reply))
}
select {
case got := <-done:
if string(got) != string(payload) {
t.Fatalf("unexpected payload seen by target unix socket: %q", string(got))
}
case <-time.After(2 * time.Second):
t.Fatal("target unix socket did not receive forwarded payload")
}
select {
case got := <-dialRecordCh:
if got.network != "unix" || got.addr != targetSocketPath {
t.Fatalf("unexpected unix target dial: network=%q addr=%q", got.network, got.addr)
}
case <-time.After(2 * time.Second):
t.Fatal("did not observe unix target dial")
}
}
func TestStartLocalTCPToUnixForwardUsesUnixTarget(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("unix socket smoke test is exercised in WSL/Linux CI path")
}
oldDialSSHClient := dialSSHClient
t.Cleanup(func() {
dialSSHClient = oldDialSSHClient
})
baseClient := &ssh.Client{}
star := &StarSSH{}
star.setTransport(baseClient, nil)
targetSocketPath := filepath.Join(t.TempDir(), "target-tcp-to-unix.sock")
targetListener, err := net.Listen("unix", targetSocketPath)
if err != nil {
t.Fatalf("listen target unix socket: %v", err)
}
defer targetListener.Close()
done := make(chan []byte, 1)
go func() {
conn, acceptErr := targetListener.Accept()
if acceptErr != nil {
done <- nil
return
}
defer conn.Close()
buf := make([]byte, 64)
n, _ := conn.Read(buf)
_, _ = conn.Write(buf[:n])
done <- buf[:n]
}()
dialRecordCh := make(chan dialRecord, 1)
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
if client != baseClient {
t.Fatalf("unexpected ssh client %p", client)
}
dialRecordCh <- dialRecord{network: network, addr: address}
var dialer net.Dialer
return dialer.DialContext(ctx, network, address)
}
forwarder, err := star.StartLocalTCPToUnixForward("127.0.0.1:0", targetSocketPath)
if err != nil {
t.Fatalf("start local tcp-to-unix forward: %v", err)
}
defer func() {
closeErr := forwarder.Close()
if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) {
t.Fatalf("close local tcp-to-unix forward: %v", closeErr)
}
}()
reply := exerciseForwarder(t, forwarder.Addr().String(), []byte("tcp-to-unix"))
if string(reply) != "tcp-to-unix" {
t.Fatalf("unexpected tcp-to-unix reply: %q", string(reply))
}
select {
case got := <-done:
if string(got) != "tcp-to-unix" {
t.Fatalf("unexpected payload seen by unix target: %q", string(got))
}
case <-time.After(2 * time.Second):
t.Fatal("unix target did not receive forwarded tcp payload")
}
select {
case got := <-dialRecordCh:
if got.network != "unix" || got.addr != targetSocketPath {
t.Fatalf("unexpected unix target dial: network=%q addr=%q", got.network, got.addr)
}
case <-time.After(2 * time.Second):
t.Fatal("did not observe unix target dial")
}
}
func echoForwardPipe(conn net.Conn) {
defer conn.Close()
buf := make([]byte, 4096)
n, err := conn.Read(buf)
if err != nil {
return
}
_, _ = conn.Write(buf[:n])
}
func exerciseForwarder(t *testing.T, addr string, payload []byte) []byte {
t.Helper()
conn, err := net.DialTimeout("tcp", addr, time.Second)
if err != nil {
t.Fatalf("dial forward listener: %v", err)
}
defer conn.Close()
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
if _, err := conn.Write(payload); err != nil {
t.Fatalf("write forwarded payload: %v", err)
}
reply := make([]byte, len(payload))
if _, err := io.ReadFull(conn, reply); err != nil {
t.Fatalf("read forwarded reply: %v", err)
}
return reply
}
+12 -3
View File
@@ -1,8 +1,17 @@
module b612.me/starssh
go 1.16
go 1.20
require (
github.com/pkg/sftp v1.13.4
golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000
github.com/Microsoft/go-winio v0.6.1
github.com/pkg/sftp v1.13.9
golang.org/x/crypto v0.33.0
golang.org/x/sys v0.30.0
)
require (
github.com/kr/fs v0.1.0 // indirect
golang.org/x/mod v0.17.0 // indirect
golang.org/x/sync v0.10.0 // indirect
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
)
+78 -15
View File
@@ -1,29 +1,92 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow=
github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
github.com/pkg/sftp v1.13.4 h1:Lb0RYJCmgUcBgZosfoi9Y9sbl6+LJgOIgk/2Y4YjMFg=
github.com/pkg/sftp v1.13.4/go.mod h1:LzqnAvaD5TWeNBsZpfKxSYn1MbjWwOsCIAFFJbpIsK8=
github.com/pkg/sftp v1.13.9 h1:4NGkvGudBL7GteO3m6qnaQ4pC0Kvf0onSVc9gR3EWBw=
github.com/pkg/sftp v1.13.9/go.mod h1:OBN7bVXdstkFFN/gdnHPUb5TE8eb8G1Rp9wCItqjkkA=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000 h1:SL+8VVnkqyshUSz5iNnXtrBQzvFF2SkROm6t5RczFAE=
golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+290
View File
@@ -0,0 +1,290 @@
package starssh
import (
"errors"
"fmt"
"net"
"os"
"path/filepath"
"strings"
"sync"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/knownhosts"
)
var ErrKnownHostsFileRequired = errors.New("known_hosts file is required")
var ErrHostFingerprintRequired = errors.New("host key fingerprint is required")
type HostKeyFingerprintMismatchError struct {
Expected []string
ActualSHA256 string
ActualLegacyMD5 string
}
type AcceptNewHostKeyOptions struct {
KnownHostsFile string
HashHosts bool
IncludeRemoteAddress bool
FileMode os.FileMode
}
func (e *HostKeyFingerprintMismatchError) Error() string {
if e == nil {
return ""
}
return fmt.Sprintf("host key fingerprint mismatch: want one of %s, got %s (%s)", strings.Join(e.Expected, ", "), e.ActualSHA256, e.ActualLegacyMD5)
}
func KnownHostsHostKeyCallback(files ...string) (func(string, net.Addr, ssh.PublicKey) error, error) {
trimmed := make([]string, 0, len(files))
for _, file := range files {
file = strings.TrimSpace(file)
if file == "" {
continue
}
trimmed = append(trimmed, file)
}
if len(trimmed) == 0 {
return nil, ErrKnownHostsFileRequired
}
return knownhosts.New(trimmed...)
}
func AcceptNewHostKeyCallback(file string) (func(string, net.Addr, ssh.PublicKey) error, error) {
return AcceptNewHostKeyCallbackWithOptions(AcceptNewHostKeyOptions{
KnownHostsFile: file,
IncludeRemoteAddress: true,
})
}
func AcceptNewHostKeyCallbackWithOptions(options AcceptNewHostKeyOptions) (func(string, net.Addr, ssh.PublicKey) error, error) {
options = normalizeAcceptNewHostKeyOptions(options)
if options.KnownHostsFile == "" {
return nil, ErrKnownHostsFileRequired
}
state := &acceptNewHostKeyState{
file: options.KnownHostsFile,
hashHosts: options.HashHosts,
includeRemoteAddress: options.IncludeRemoteAddress,
fileMode: options.FileMode,
}
if err := state.reload(); err != nil {
return nil, err
}
return state.checkHostKey, nil
}
func FingerprintHostKeyCallback(fingerprints ...string) (func(string, net.Addr, ssh.PublicKey) error, error) {
normalized := make([]string, 0, len(fingerprints))
seen := make(map[string]struct{}, len(fingerprints))
for _, raw := range fingerprints {
fingerprint, err := normalizeHostKeyFingerprint(raw)
if err != nil {
return nil, err
}
if fingerprint == "" {
continue
}
if _, exists := seen[fingerprint]; exists {
continue
}
seen[fingerprint] = struct{}{}
normalized = append(normalized, fingerprint)
}
if len(normalized) == 0 {
return nil, ErrHostFingerprintRequired
}
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
actualSHA256 := ssh.FingerprintSHA256(key)
actualMD5 := normalizeMD5Fingerprint(ssh.FingerprintLegacyMD5(key))
for _, want := range normalized {
if want == actualSHA256 || want == actualMD5 {
return nil
}
}
return &HostKeyFingerprintMismatchError{
Expected: append([]string(nil), normalized...),
ActualSHA256: actualSHA256,
ActualLegacyMD5: actualMD5,
}
}, nil
}
type acceptNewHostKeyState struct {
file string
hashHosts bool
includeRemoteAddress bool
fileMode os.FileMode
mu sync.Mutex
cb ssh.HostKeyCallback
}
func (s *acceptNewHostKeyState) checkHostKey(hostname string, remote net.Addr, key ssh.PublicKey) error {
s.mu.Lock()
defer s.mu.Unlock()
if s.cb != nil {
err := s.cb(hostname, remote, key)
if err == nil {
return nil
}
var keyErr *knownhosts.KeyError
if !errors.As(err, &keyErr) || len(keyErr.Want) != 0 {
return err
}
}
line, err := buildAcceptNewKnownHostsLine(hostname, remote, key, s.hashHosts, s.includeRemoteAddress)
if err != nil {
return err
}
if err := appendKnownHostsLine(s.file, line, s.fileMode); err != nil {
return err
}
if err := s.reload(); err != nil {
return err
}
if s.cb == nil {
return errors.New("known_hosts callback is nil after reload")
}
return s.cb(hostname, remote, key)
}
func (s *acceptNewHostKeyState) reload() error {
callback, err := loadKnownHostsCallback(s.file)
if err != nil {
return err
}
s.cb = callback
return nil
}
func loadKnownHostsCallback(file string) (ssh.HostKeyCallback, error) {
_, err := os.Stat(file)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil, nil
}
return nil, err
}
return knownhosts.New(file)
}
func normalizeAcceptNewHostKeyOptions(options AcceptNewHostKeyOptions) AcceptNewHostKeyOptions {
options.KnownHostsFile = strings.TrimSpace(options.KnownHostsFile)
if options.FileMode == 0 {
options.FileMode = 0o600
}
return options
}
func buildAcceptNewKnownHostsLine(hostname string, remote net.Addr, key ssh.PublicKey, hashHosts bool, includeRemoteAddress bool) (string, error) {
addresses := collectKnownHostsAddresses(hostname, remote, includeRemoteAddress)
if len(addresses) == 0 {
return "", errors.New("no hostname or remote address available for known_hosts entry")
}
patterns := make([]string, 0, len(addresses))
for _, address := range addresses {
normalized := knownhosts.Normalize(address)
if hashHosts {
patterns = append(patterns, knownhosts.HashHostname(normalized))
continue
}
patterns = append(patterns, normalized)
}
authorizedKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(key)))
return strings.Join(patterns, ",") + " " + authorizedKey, nil
}
func collectKnownHostsAddresses(hostname string, remote net.Addr, includeRemoteAddress bool) []string {
addresses := make([]string, 0, 2)
seen := make(map[string]struct{}, 2)
add := func(address string) {
address = strings.TrimSpace(address)
if address == "" {
return
}
normalized := knownhosts.Normalize(address)
if _, exists := seen[normalized]; exists {
return
}
seen[normalized] = struct{}{}
addresses = append(addresses, address)
}
add(hostname)
if includeRemoteAddress && remote != nil {
add(remote.String())
}
return addresses
}
func appendKnownHostsLine(file string, line string, mode os.FileMode) error {
if strings.TrimSpace(file) == "" {
return ErrKnownHostsFileRequired
}
if strings.TrimSpace(line) == "" {
return errors.New("known_hosts line is empty")
}
dir := filepath.Dir(file)
if dir != "." && dir != "" {
if err := os.MkdirAll(dir, 0o700); err != nil {
return err
}
}
handle, err := os.OpenFile(file, os.O_CREATE|os.O_APPEND|os.O_WRONLY, mode)
if err != nil {
return err
}
defer handle.Close()
if _, err := handle.WriteString(line + "\n"); err != nil {
return err
}
return handle.Chmod(mode)
}
func normalizeHostKeyFingerprint(raw string) (string, error) {
value := strings.TrimSpace(raw)
if value == "" {
return "", nil
}
if strings.HasPrefix(strings.ToUpper(value), "SHA256:") {
suffix := strings.TrimSpace(value[len("SHA256:"):])
if suffix == "" {
return "", ErrHostFingerprintRequired
}
return "SHA256:" + suffix, nil
}
if strings.HasPrefix(strings.ToUpper(value), "MD5:") {
suffix := strings.TrimSpace(value[len("MD5:"):])
if suffix == "" {
return "", ErrHostFingerprintRequired
}
return normalizeMD5Fingerprint("MD5:" + suffix), nil
}
if strings.Count(value, ":") >= 2 {
return normalizeMD5Fingerprint("MD5:" + value), nil
}
return "SHA256:" + value, nil
}
func normalizeMD5Fingerprint(value string) string {
if !strings.HasPrefix(strings.ToUpper(value), "MD5:") {
return "MD5:" + strings.ToLower(value)
}
return "MD5:" + strings.ToLower(value[len("MD5:"):])
}
+135
View File
@@ -0,0 +1,135 @@
package starssh
import (
"context"
"sync"
"time"
)
var sendKeepAliveRequest = func(ctx context.Context, client sshClientRequester) error {
_, _, err := client.SendRequest("keepalive@openssh.com", true, nil)
return err
}
func (s *StarSSH) Ping() error {
return s.PingContext(context.Background())
}
func (s *StarSSH) PingContext(ctx context.Context) error {
return s.pingContext(ctx, nil)
}
func (s *StarSSH) PingContextCloseOnCancel(ctx context.Context) error {
return s.pingContext(ctx, s.Close)
}
func (s *StarSSH) pingContext(ctx context.Context, onCancel func() error) error {
if ctx == nil {
ctx = context.Background()
}
client, err := s.requireSSHClient()
if err != nil {
return err
}
runCancel := func() {}
if onCancel != nil {
var cancelOnce sync.Once
runCancel = func() {
cancelOnce.Do(func() {
_ = onCancel()
})
}
cancelDone := make(chan struct{})
defer close(cancelDone)
go func() {
select {
case <-ctx.Done():
runCancel()
case <-cancelDone:
}
}()
}
requestFunc := sendKeepAliveRequest
pingErr := make(chan error, 1)
go func() {
err := requestFunc(ctx, client)
select {
case pingErr <- err:
default:
}
}()
select {
case err := <-pingErr:
return err
case <-ctx.Done():
runCancel()
return ctx.Err()
}
}
func (s *StarSSH) startAutoKeepAlive() {
if s == nil || s.snapshotSSHClient() == nil {
return
}
interval := s.LoginInfo.KeepAliveInterval
if interval <= 0 {
return
}
timeout := s.LoginInfo.KeepAliveTimeout
if timeout <= 0 {
timeout = defaultKeepAliveTimeout
}
stop := make(chan struct{})
done := make(chan struct{})
s.keepaliveMu.Lock()
if s.keepaliveStop != nil {
s.keepaliveMu.Unlock()
return
}
s.keepaliveStop = stop
s.keepaliveDone = done
s.keepaliveMu.Unlock()
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
defer close(done)
for {
select {
case <-stop:
return
case <-ticker.C:
pingCtx, cancel := context.WithTimeout(context.Background(), timeout)
err := s.pingContext(pingCtx, nil)
cancel()
if err != nil {
s.closeFromKeepAlive()
return
}
}
}
}()
}
func (s *StarSSH) stopAutoKeepAlive() {
stop, done := s.takeKeepaliveHandles()
if stop != nil {
close(stop)
}
if done != nil {
<-done
}
}
func (s *StarSSH) closeFromKeepAlive() {
_ = s.closeTransport(false)
}
+53
View File
@@ -0,0 +1,53 @@
package starssh
import (
"context"
"testing"
"time"
"golang.org/x/crypto/ssh"
)
func TestAutoKeepAliveTimeoutDoesNotDeadlock(t *testing.T) {
oldSendKeepAliveRequest := sendKeepAliveRequest
oldCloseSSHClient := closeSSHClient
t.Cleanup(func() {
sendKeepAliveRequest = oldSendKeepAliveRequest
closeSSHClient = oldCloseSSHClient
})
sendKeepAliveRequest = func(ctx context.Context, client sshClientRequester) error {
<-ctx.Done()
return ctx.Err()
}
closeSSHClient = func(client sshClientRequester) error {
return nil
}
client := &ssh.Client{}
star := &StarSSH{
LoginInfo: LoginInput{
KeepAliveInterval: 10 * time.Millisecond,
KeepAliveTimeout: 20 * time.Millisecond,
},
}
star.setTransport(client, nil)
star.startAutoKeepAlive()
star.keepaliveMu.Lock()
done := star.keepaliveDone
star.keepaliveMu.Unlock()
if done == nil {
t.Fatal("keepalive did not start")
}
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("keepalive goroutine did not exit after keepalive timeout")
}
if client := star.snapshotSSHClient(); client != nil {
t.Fatal("ssh client should be closed after keepalive timeout")
}
}
+193
View File
@@ -0,0 +1,193 @@
package starssh
import (
"context"
"encoding/base64"
"errors"
"net"
"os"
"time"
"golang.org/x/crypto/ssh"
)
var ErrHostKeyCallbackRequired = errors.New("host key callback is required; use DefaultAllowHostKeyCallback to explicitly allow any host key")
func DefaultAllowHostKeyCallback(hostname string, remote net.Addr, key ssh.PublicKey) error {
return nil
}
func LoginContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
return loginWithContext(ctx, info)
}
func Login(info LoginInput) (*StarSSH, error) {
return LoginContext(context.Background(), info)
}
func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
info = normalizeLoginInput(info)
if info.HostKeyCallback == nil {
return nil, ErrHostKeyCallbackRequired
}
authTimeout := effectiveLoginTimeout(info)
loginCtx, cancel := contextWithLoginTimeout(ctx, authTimeout)
defer cancel()
order, err := normalizeAuthOrder(info.AuthOrder)
if err != nil {
return nil, err
}
if shouldRetrySSHAgentAuth(info, order) {
agentAttempt := newSSHAgentAuthAttempt()
for {
agentAttempt.begin()
sshInfo, err := loginOnceWithContext(loginCtx, info, authTimeout, agentAttempt)
if err == nil {
return sshInfo, nil
}
if errors.Is(err, errRetrySSHAgentAuth) && loginCtx.Err() == nil {
continue
}
return sshInfo, err
}
}
return loginOnceWithContext(loginCtx, info, authTimeout, nil)
}
func loginOnceWithContext(ctx context.Context, info LoginInput, authTimeout time.Duration, agentAttempt *sshAgentAuthAttempt) (*StarSSH, error) {
sshInfo := &StarSSH{
LoginInfo: info,
}
auth, authCleanup, err := buildAuthMethodsWithAgentAttempt(info, agentAttempt)
if err != nil {
return nil, err
}
if authCleanup != nil {
defer authCleanup()
}
hostKeyCallback := func(hostname string, remote net.Addr, key ssh.PublicKey) error {
sshInfo.PublicKey = key
sshInfo.RemoteAddr = remote
sshInfo.Hostname = hostname
return info.HostKeyCallback(hostname, remote, key)
}
bannerCallback := func(banner string) error {
sshInfo.Banner = banner
if info.BannerCallback != nil {
return info.BannerCallback(banner)
}
return nil
}
clientConfig := &ssh.ClientConfig{
User: info.User,
Auth: auth,
Timeout: authTimeout,
HostKeyCallback: hostKeyCallback,
BannerCallback: bannerCallback,
}
if len(info.Ciphers) > 0 || len(info.MACs) > 0 || len(info.KeyExchanges) > 0 {
clientConfig.Config = ssh.Config{
Ciphers: info.Ciphers,
MACs: info.MACs,
KeyExchanges: info.KeyExchanges,
}
}
targetAddr := joinHostPort(info.Addr, info.Port)
rawConn, upstream, err := dialTargetConn(ctx, info)
if err != nil {
return sshInfo, err
}
restoreDeadline := applyConnDeadline(rawConn, ctx, authTimeout)
defer restoreDeadline()
clientConn, chans, reqs, err := ssh.NewClientConn(rawConn, targetAddr, clientConfig)
if err != nil {
_ = rawConn.Close()
if upstream != nil {
_ = upstream.Close()
}
return sshInfo, err
}
client := ssh.NewClient(clientConn, chans, reqs)
sshInfo.setTransport(client, upstream)
if sshInfo.PublicKey != nil {
sshInfo.PubkeyBase64 = base64.StdEncoding.EncodeToString(sshInfo.PublicKey.Marshal())
}
sshInfo.startAutoKeepAlive()
return sshInfo, nil
}
func contextWithLoginTimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
if ctx == nil {
ctx = context.Background()
}
if timeout <= 0 {
return ctx, func() {}
}
return context.WithTimeout(ctx, timeout)
}
func LoginSimple(host string, user string, passwd string, prikeyPath string, port int, timeout time.Duration) (*StarSSH, error) {
info := LoginInput{
Addr: host,
Port: port,
Timeout: timeout,
DialTimeout: timeout,
User: user,
HostKeyCallback: DefaultAllowHostKeyCallback,
}
if prikeyPath != "" {
prikey, err := os.ReadFile(prikeyPath)
if err != nil {
return nil, err
}
info.Prikey = string(prikey)
if passwd != "" {
info.PrikeyPwd = passwd
}
} else {
info.Password = passwd
}
return Login(info)
}
func normalizeLoginInput(info LoginInput) LoginInput {
if info.Port <= 0 {
info.Port = defaultSSHPort
}
return info
}
func effectiveLoginTimeout(info LoginInput) time.Duration {
if info.Timeout <= 0 {
return 0
}
return info.Timeout
}
func effectiveDialTimeout(info LoginInput) time.Duration {
switch {
case info.DialTimeout < 0:
return 0
case info.DialTimeout > 0:
return info.DialTimeout
case info.Timeout > 0:
return info.Timeout
default:
return defaultLoginTimeout
}
}
+763
View File
@@ -0,0 +1,763 @@
package starssh
import (
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
"errors"
"io"
"net"
"os"
"sync"
"testing"
"time"
"golang.org/x/crypto/ssh"
sshagent "golang.org/x/crypto/ssh/agent"
)
func TestNormalizeLoginInputKeepsZeroAuthTimeout(t *testing.T) {
info := normalizeLoginInput(LoginInput{})
if info.Port != defaultSSHPort {
t.Fatalf("Port=%d want %d", info.Port, defaultSSHPort)
}
if info.Timeout != 0 {
t.Fatalf("Timeout=%v want 0", info.Timeout)
}
if info.DialTimeout != 0 {
t.Fatalf("DialTimeout=%v want 0", info.DialTimeout)
}
if info.SSHAgentTimeout != 0 {
t.Fatalf("SSHAgentTimeout=%v want 0", info.SSHAgentTimeout)
}
if info.SSHAgentForwardTimeout != 0 {
t.Fatalf("SSHAgentForwardTimeout=%v want 0", info.SSHAgentForwardTimeout)
}
}
func TestEffectiveLoginTimeout(t *testing.T) {
if got := effectiveLoginTimeout(LoginInput{}); got != 0 {
t.Fatalf("zero login timeout should stay zero, got %v", got)
}
if got := effectiveLoginTimeout(LoginInput{Timeout: 7 * time.Second}); got != 7*time.Second {
t.Fatalf("expected explicit login timeout, got %v", got)
}
}
func TestEffectiveDialTimeout(t *testing.T) {
tests := []struct {
name string
info LoginInput
want time.Duration
}{
{
name: "default fallback",
info: LoginInput{},
want: defaultLoginTimeout,
},
{
name: "reuse timeout when dial timeout omitted",
info: LoginInput{Timeout: 9 * time.Second},
want: 9 * time.Second,
},
{
name: "explicit dial timeout wins",
info: LoginInput{Timeout: 9 * time.Second, DialTimeout: 3 * time.Second},
want: 3 * time.Second,
},
{
name: "negative dial timeout disables default dial deadline",
info: LoginInput{Timeout: 9 * time.Second, DialTimeout: -1},
want: 0,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := effectiveDialTimeout(tc.info); got != tc.want {
t.Fatalf("effectiveDialTimeout(%+v)=%v want %v", tc.info, got, tc.want)
}
})
}
}
func TestEffectiveSSHAgentTimeout(t *testing.T) {
tests := []struct {
name string
info LoginInput
want time.Duration
}{
{
name: "default fallback without auth timeout",
info: LoginInput{},
want: defaultSSHAgentTimeout,
},
{
name: "auth timeout does not cap default",
info: LoginInput{Timeout: 9 * time.Second},
want: defaultSSHAgentTimeout,
},
{
name: "explicit agent timeout wins",
info: LoginInput{Timeout: 9 * time.Second, DialTimeout: 3 * time.Second, SSHAgentTimeout: 90 * time.Second},
want: 90 * time.Second,
},
{
name: "negative agent timeout disables operation deadline",
info: LoginInput{SSHAgentTimeout: -1},
want: 0,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := effectiveSSHAgentTimeout(tc.info); got != tc.want {
t.Fatalf("effectiveSSHAgentTimeout(%+v)=%v want %v", tc.info, got, tc.want)
}
})
}
}
func TestEffectiveSSHAgentForwardTimeout(t *testing.T) {
if got := effectiveSSHAgentForwardTimeout(LoginInput{}); got != 0 {
t.Fatalf("zero forward timeout should stay zero, got %v", got)
}
if got := effectiveSSHAgentForwardTimeout(LoginInput{SSHAgentForwardTimeout: 4 * time.Second}); got != 4*time.Second {
t.Fatalf("expected explicit forward timeout, got %v", got)
}
}
func TestBuildAuthMethodsUsesSeparateSSHAgentTimeouts(t *testing.T) {
oldBuilder := buildSSHAgentAuthMethodFunc
t.Cleanup(func() {
buildSSHAgentAuthMethodFunc = oldBuilder
})
captured := sshAgentTimeouts{Dial: -2, Operation: -2, Forward: -2}
buildSSHAgentAuthMethodFunc = func(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
captured = timeouts
return ssh.Password("agent"), nil, nil
}
info := LoginInput{
Timeout: 0,
DialTimeout: 11 * time.Second,
SSHAgentTimeout: 90 * time.Second,
SSHAgentForwardTimeout: 4 * time.Second,
IdentityAgent: "/tmp/custom-agent.sock",
AuthOrder: []AuthMethodKind{AuthMethodSSHAgent},
}
auth, cleanup, err := buildAuthMethods(info)
if err != nil {
t.Fatalf("buildAuthMethods: %v", err)
}
if cleanup != nil {
cleanup()
}
if len(auth) != 1 {
t.Fatalf("expected one auth method, got %d", len(auth))
}
if captured.Dial != 11*time.Second {
t.Fatalf("agent auth builder dial timeout=%v want %v", captured.Dial, 11*time.Second)
}
if captured.Operation != 90*time.Second {
t.Fatalf("agent auth builder operation timeout=%v want %v", captured.Operation, 90*time.Second)
}
if captured.Forward != 4*time.Second {
t.Fatalf("agent auth builder forward timeout=%v want %v", captured.Forward, 4*time.Second)
}
if captured.Endpoint != "/tmp/custom-agent.sock" {
t.Fatalf("agent auth builder endpoint=%q want custom endpoint", captured.Endpoint)
}
}
func TestBuildAuthMethodsUsesSingleAgentAuthMethod(t *testing.T) {
oldBuilder := buildSSHAgentAuthMethodFunc
t.Cleanup(func() {
buildSSHAgentAuthMethodFunc = oldBuilder
})
buildSSHAgentAuthMethodFunc = func(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
return ssh.Password("agent"), nil, nil
}
auth, cleanup, err := buildAuthMethods(LoginInput{
AuthOrder: []AuthMethodKind{AuthMethodSSHAgent},
})
if err != nil {
t.Fatalf("buildAuthMethods: %v", err)
}
if cleanup != nil {
cleanup()
}
if len(auth) != 1 {
t.Fatalf("auth methods=%d, want 1", len(auth))
}
}
func TestShouldRetrySSHAgentAuthWhenAgentIsNotFirst(t *testing.T) {
order := []AuthMethodKind{AuthMethodPassword, AuthMethodSSHAgent}
if !shouldRetrySSHAgentAuth(LoginInput{}, order) {
t.Fatal("expected ssh-agent retry when ssh-agent is present after password")
}
if shouldRetrySSHAgentAuth(LoginInput{DisableSSHAgent: true}, order) {
t.Fatal("expected ssh-agent retry disabled when DisableSSHAgent is true")
}
if shouldRetrySSHAgentAuth(LoginInput{}, []AuthMethodKind{AuthMethodPassword}) {
t.Fatal("expected no ssh-agent retry when ssh-agent auth is absent")
}
}
func TestBuildAuthMethodsWithAgentAttemptMarksNonFirstAgentForRetry(t *testing.T) {
oldBuilder := buildSSHAgentAuthMethodFunc
t.Cleanup(func() {
buildSSHAgentAuthMethodFunc = oldBuilder
})
buildSSHAgentAuthMethodFunc = func(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
if timeouts.SignFailure == nil {
t.Fatal("expected SignFailure callback for non-first ssh-agent auth")
}
if timeouts.SkipFingerprints != nil {
t.Fatalf("unexpected initial skip fingerprints: %#v", timeouts.SkipFingerprints)
}
return ssh.Password("agent"), nil, nil
}
auth, cleanup, err := buildAuthMethodsWithAgentAttempt(LoginInput{
Password: "secret",
AuthOrder: []AuthMethodKind{AuthMethodPassword, AuthMethodSSHAgent},
}, newSSHAgentAuthAttempt())
if err != nil {
t.Fatalf("buildAuthMethodsWithAgentAttempt: %v", err)
}
if cleanup != nil {
cleanup()
}
if len(auth) != 2 {
t.Fatalf("auth methods=%d want 2", len(auth))
}
}
func TestAgentRetryPendingBlocksFallbackAuthThenResets(t *testing.T) {
attempt := newSSHAgentAuthAttempt()
attempt.skipFingerprint("SHA256:test")
if err := checkSSHAgentRetryPending(attempt); !errors.Is(err, errRetrySSHAgentAuth) {
t.Fatalf("retry pending err=%v want errRetrySSHAgentAuth", err)
}
attempt.begin()
if err := checkSSHAgentRetryPending(attempt); err != nil {
t.Fatalf("retry should reset on next attempt: %v", err)
}
}
func TestAgentRetryPendingBlocksPrivateKeyAuth(t *testing.T) {
signer := mustGenerateTestSigner(t)
attempt := newSSHAgentAuthAttempt()
callback := privateKeySignersCallback(signer, attempt)
signers, err := callback()
if err != nil {
t.Fatalf("private key callback before retry: %v", err)
}
if len(signers) != 1 || signers[0] != signer {
t.Fatalf("private key callback returned %#v, want original signer", signers)
}
attempt.skipFingerprint("SHA256:test")
signers, err = callback()
if !errors.Is(err, errRetrySSHAgentAuth) {
t.Fatalf("private key callback err=%v want errRetrySSHAgentAuth", err)
}
if signers != nil {
t.Fatalf("private key callback signers=%#v want nil while retry pending", signers)
}
attempt.begin()
signers, err = callback()
if err != nil {
t.Fatalf("private key callback after retry reset: %v", err)
}
if len(signers) != 1 || signers[0] != signer {
t.Fatalf("private key callback after retry returned %#v, want original signer", signers)
}
}
func TestFilterSSHAgentSignersSkipsSignerAfterSignFailure(t *testing.T) {
firstSigner := mustGenerateTestSigner(t)
secondSigner := mustGenerateTestSigner(t)
failingFirstSigner := &testFailingSigner{Signer: firstSigner, err: errors.New("first agent key cannot sign")}
attempt := newSSHAgentAuthAttempt()
firstMethods := filterSSHAgentSignersForRetry([]ssh.Signer{failingFirstSigner, secondSigner}, sshAgentTimeouts{
SignFailure: attempt.recordSignFailure,
SkipFingerprints: attempt.skipSnapshot(),
})
if len(firstMethods) != 2 {
t.Fatalf("first auth method signers=%d want 2", len(firstMethods))
}
if _, err := firstMethods[0].Sign(nil, []byte("challenge")); !errors.Is(err, errRetrySSHAgentAuth) {
t.Fatalf("first signer err=%v want errRetrySSHAgentAuth", err)
}
secondMethods := filterSSHAgentSignersForRetry([]ssh.Signer{failingFirstSigner, secondSigner}, sshAgentTimeouts{
SignFailure: attempt.recordSignFailure,
SkipFingerprints: attempt.skipSnapshot(),
})
if len(secondMethods) != 1 {
t.Fatalf("second auth method signers=%d want 1", len(secondMethods))
}
if string(secondMethods[0].PublicKey().Marshal()) != string(secondSigner.PublicKey().Marshal()) {
t.Fatalf("second auth method did not skip failed first key")
}
signature, err := secondMethods[0].Sign(nil, []byte("challenge"))
if err != nil {
t.Fatalf("second signer Sign: %v", err)
}
if signature == nil {
t.Fatal("second signer returned nil signature")
}
}
func TestBuildAuthMethodsSkipsFailedAgentSignerOnRetry(t *testing.T) {
firstSigner := mustGenerateTestSigner(t)
secondSigner := mustGenerateTestSigner(t)
wantErr := errors.New("first agent key cannot sign")
failingFirstSigner := &testFailingSigner{Signer: firstSigner, err: wantErr}
oldBuilder := buildSSHAgentAuthMethodFunc
t.Cleanup(func() {
buildSSHAgentAuthMethodFunc = oldBuilder
})
var buildCalls int
buildSSHAgentAuthMethodFunc = func(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
buildCalls++
filteredSigners := filterSSHAgentSignersForRetry([]ssh.Signer{failingFirstSigner, secondSigner}, timeouts)
if buildCalls == 1 {
if len(filteredSigners) != 2 {
t.Fatalf("first build signers=%d want 2", len(filteredSigners))
}
return ssh.PublicKeys(filteredSigners...), nil, nil
}
if len(filteredSigners) != 1 {
t.Fatalf("retry build signers=%d want 1", len(filteredSigners))
}
if string(filteredSigners[0].PublicKey().Marshal()) != string(secondSigner.PublicKey().Marshal()) {
t.Fatal("retry build did not skip failed signer")
}
return ssh.PublicKeys(filteredSigners...), nil, nil
}
attempt := newSSHAgentAuthAttempt()
attempt.begin()
auth, cleanup, err := buildAuthMethodsWithAgentAttempt(LoginInput{
AuthOrder: []AuthMethodKind{AuthMethodSSHAgent},
}, attempt)
if err != nil {
t.Fatalf("first buildAuthMethodsWithAgentAttempt: %v", err)
}
if cleanup != nil {
cleanup()
}
if len(auth) != 1 {
t.Fatalf("first auth methods=%d want 1", len(auth))
}
if _, err := failingFirstSigner.Sign(rand.Reader, []byte("challenge")); !errors.Is(err, wantErr) {
t.Fatalf("raw failing signer err=%v", err)
}
firstWrapped := filterSSHAgentSignersForRetry([]ssh.Signer{failingFirstSigner}, sshAgentTimeouts{
SignFailure: attempt.recordSignFailure,
})[0]
if _, err := firstWrapped.Sign(rand.Reader, []byte("challenge")); !errors.Is(err, errRetrySSHAgentAuth) {
t.Fatalf("wrapped failing signer err=%v want errRetrySSHAgentAuth", err)
}
attempt.begin()
auth, cleanup, err = buildAuthMethodsWithAgentAttempt(LoginInput{
AuthOrder: []AuthMethodKind{AuthMethodSSHAgent},
}, attempt)
if err != nil {
t.Fatalf("retry buildAuthMethodsWithAgentAttempt: %v", err)
}
if cleanup != nil {
cleanup()
}
if len(auth) != 1 {
t.Fatalf("retry auth methods=%d want 1", len(auth))
}
if buildCalls != 2 {
t.Fatalf("build calls=%d want 2", buildCalls)
}
}
func TestOrderSSHAgentSignersPrefersPriorityComment(t *testing.T) {
plainSigner := mustGenerateTestSigner(t)
prioritySigner := mustGenerateCommentedTestSigner(t, "priority=40")
ordered := orderSSHAgentSigners([]ssh.Signer{plainSigner, prioritySigner})
if len(ordered) != 2 {
t.Fatalf("ordered signers=%d want 2", len(ordered))
}
if string(ordered[0].PublicKey().Marshal()) != string(prioritySigner.PublicKey().Marshal()) {
t.Fatalf("priority signer should be first, got %s", sshAgentSignerComment(ordered[0]))
}
}
func TestOrderSSHAgentSignersPrefersCardKeys(t *testing.T) {
plainSigner := mustGenerateTestSigner(t)
cardSigner := mustGenerateCommentedTestSigner(t, "cardno:26_865_673")
ordered := orderSSHAgentSigners([]ssh.Signer{plainSigner, cardSigner})
if len(ordered) != 2 {
t.Fatalf("ordered signers=%d want 2", len(ordered))
}
if string(ordered[0].PublicKey().Marshal()) != string(cardSigner.PublicKey().Marshal()) {
t.Fatalf("card signer should be first, got %s", sshAgentSignerComment(ordered[0]))
}
}
func TestOrderSSHAgentSignersKeepsStableOrderWithoutHints(t *testing.T) {
firstSigner := mustGenerateTestSigner(t)
secondSigner := mustGenerateTestSigner(t)
ordered := orderSSHAgentSigners([]ssh.Signer{firstSigner, secondSigner})
if len(ordered) != 2 {
t.Fatalf("ordered signers=%d want 2", len(ordered))
}
if string(ordered[0].PublicKey().Marshal()) != string(firstSigner.PublicKey().Marshal()) {
t.Fatalf("first signer changed order without hints")
}
if string(ordered[1].PublicKey().Marshal()) != string(secondSigner.PublicKey().Marshal()) {
t.Fatalf("second signer changed order without hints")
}
}
func TestSSHAgentSignerEmitsSignDebugWithoutChangingError(t *testing.T) {
signer := mustGenerateTestSigner(t)
wantErr := errors.New("agent refused operation")
var debugCalls int
wrapped := wrapSSHAgentSigner(&testFailingSigner{Signer: signer, err: wantErr}, sshAgentSignerOptions{
Resolved: resolvedSSHAgentEndpoint{
Endpoint: "/tmp/debug-agent.sock",
Source: "identity-agent",
Network: "unix",
},
Debug: func(event SSHAgentDebugEvent) {
debugCalls++
if event.Step != "auth" || event.Phase != "sign" {
t.Fatalf("unexpected debug event: %+v", event)
}
if event.Endpoint != "/tmp/debug-agent.sock" || event.Source != "identity-agent" || event.Network != "unix" {
t.Fatalf("unexpected endpoint details: %+v", event)
}
if event.Status != "error" || !errors.Is(event.Err, wantErr) {
t.Fatalf("unexpected sign status: %+v", event)
}
},
})
_, err := wrapped.Sign(rand.Reader, []byte("challenge"))
if !errors.Is(err, wantErr) {
t.Fatalf("Sign err=%v want original signer error", err)
}
if debugCalls != 1 {
t.Fatalf("debug calls=%d want 1", debugCalls)
}
}
func TestSSHAgentRetrySignerPrefersRSASHA2(t *testing.T) {
signer := mustGenerateRSATestSigner(t)
spy := &testAlgorithmSpySigner{Signer: signer}
wrapped, ok := wrapSSHAgentSignerForRetry(spy, func(ssh.PublicKey, error) {}).(ssh.AlgorithmSigner)
if !ok {
t.Fatal("wrapped signer does not implement AlgorithmSigner")
}
signature, err := wrapped.SignWithAlgorithm(rand.Reader, []byte("challenge"), ssh.KeyAlgoRSA)
if err != nil {
t.Fatalf("SignWithAlgorithm: %v", err)
}
if spy.lastAlgorithm != ssh.KeyAlgoRSASHA256 {
t.Fatalf("last algorithm=%q want %q", spy.lastAlgorithm, ssh.KeyAlgoRSASHA256)
}
if signature.Format != ssh.KeyAlgoRSASHA256 {
t.Fatalf("signature format=%q want %q", signature.Format, ssh.KeyAlgoRSASHA256)
}
}
func TestSSHAgentRetrySignerKeepsRestrictedRSA(t *testing.T) {
signer := mustGenerateRSATestSigner(t)
restricted, err := ssh.NewSignerWithAlgorithms(signer.(ssh.AlgorithmSigner), []string{ssh.KeyAlgoRSA})
if err != nil {
t.Fatalf("NewSignerWithAlgorithms: %v", err)
}
spy := &testMultiAlgorithmSpySigner{
testAlgorithmSpySigner: &testAlgorithmSpySigner{Signer: restricted},
}
wrapped, ok := wrapSSHAgentSignerForRetry(spy, func(ssh.PublicKey, error) {}).(ssh.AlgorithmSigner)
if !ok {
t.Fatal("wrapped signer does not implement AlgorithmSigner")
}
signature, err := wrapped.SignWithAlgorithm(rand.Reader, []byte("challenge"), ssh.KeyAlgoRSA)
if err != nil {
t.Fatalf("SignWithAlgorithm: %v", err)
}
if spy.lastAlgorithm != ssh.KeyAlgoRSA {
t.Fatalf("last algorithm=%q want %q", spy.lastAlgorithm, ssh.KeyAlgoRSA)
}
if signature.Format != ssh.KeyAlgoRSA {
t.Fatalf("signature format=%q want %q", signature.Format, ssh.KeyAlgoRSA)
}
}
type deadlineSpyConn struct {
net.Conn
mu sync.Mutex
deadlines []time.Time
readErr error
writeErr error
}
type testFailingSigner struct {
ssh.Signer
err error
}
func (s *testFailingSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) {
return nil, s.err
}
func (s *testFailingSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) {
return nil, s.err
}
type testAlgorithmSpySigner struct {
ssh.Signer
lastAlgorithm string
}
func (s *testAlgorithmSpySigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) {
s.lastAlgorithm = algorithm
return s.Signer.(ssh.AlgorithmSigner).SignWithAlgorithm(rand, data, algorithm)
}
type testMultiAlgorithmSpySigner struct {
*testAlgorithmSpySigner
}
func (s *testMultiAlgorithmSpySigner) Algorithms() []string {
if multiAlgorithmSigner, ok := s.Signer.(ssh.MultiAlgorithmSigner); ok {
return multiAlgorithmSigner.Algorithms()
}
return nil
}
func mustGenerateTestSigner(t *testing.T) ssh.Signer {
t.Helper()
_, key, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatalf("generate test private key: %v", err)
}
signer, err := ssh.NewSignerFromKey(key)
if err != nil {
t.Fatalf("new test signer: %v", err)
}
return signer
}
func mustGenerateCommentedTestSigner(t *testing.T, comment string) ssh.Signer {
t.Helper()
baseSigner := mustGenerateTestSigner(t)
publicKey := baseSigner.PublicKey()
return &commentedTestSigner{
Signer: baseSigner,
publicKey: &sshagent.Key{
Format: publicKey.Type(),
Blob: publicKey.Marshal(),
Comment: comment,
},
}
}
type commentedTestSigner struct {
ssh.Signer
publicKey ssh.PublicKey
}
func (s *commentedTestSigner) PublicKey() ssh.PublicKey {
return s.publicKey
}
func mustGenerateRSATestSigner(t *testing.T) ssh.Signer {
t.Helper()
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("generate rsa test private key: %v", err)
}
signer, err := ssh.NewSignerFromKey(key)
if err != nil {
t.Fatalf("new rsa test signer: %v", err)
}
return signer
}
func (c *deadlineSpyConn) SetDeadline(deadline time.Time) error {
c.mu.Lock()
defer c.mu.Unlock()
c.deadlines = append(c.deadlines, deadline)
return nil
}
func (c *deadlineSpyConn) deadlineCount() int {
c.mu.Lock()
defer c.mu.Unlock()
return len(c.deadlines)
}
func (c *deadlineSpyConn) firstDeadline() time.Time {
c.mu.Lock()
defer c.mu.Unlock()
return c.deadlines[0]
}
func (c *deadlineSpyConn) Read(p []byte) (int, error) {
if c.readErr != nil {
return 0, c.readErr
}
return 0, nil
}
func (c *deadlineSpyConn) Write(p []byte) (int, error) {
if c.writeErr != nil {
return 0, c.writeErr
}
return len(p), nil
}
func TestWrapSSHAgentConnWithDeadlineSetsReadDeadline(t *testing.T) {
spy := &deadlineSpyConn{readErr: io.EOF}
conn := wrapSSHAgentConnWithDeadline(spy, 2*time.Second)
buf := make([]byte, 1)
if _, err := conn.Read(buf); !errors.Is(err, io.EOF) {
t.Fatalf("Read err=%v", err)
}
if spy.deadlineCount() != 1 {
t.Fatalf("deadlines=%d want 1", spy.deadlineCount())
}
if firstDeadline := spy.firstDeadline(); time.Until(firstDeadline) <= 0 {
t.Fatalf("deadline=%v should be in the future", firstDeadline)
}
}
func TestWrapSSHAgentConnWithDeadlineSetsWriteDeadline(t *testing.T) {
spy := &deadlineSpyConn{}
conn := wrapSSHAgentConnWithDeadline(spy, 2*time.Second)
if _, err := conn.Write([]byte("x")); err != nil {
t.Fatalf("Write err=%v", err)
}
if spy.deadlineCount() != 1 {
t.Fatalf("deadlines=%d want 1", spy.deadlineCount())
}
}
func TestResolveSSHAgentEndpointUsesIdentityAgent(t *testing.T) {
t.Setenv("SSH_AUTH_SOCK", "/tmp/env-agent.sock")
resolved, err := resolveSSHAgentEndpoint(sshAgentDialOptions{Endpoint: " /tmp/identity-agent.sock "})
if err != nil {
t.Fatalf("resolveSSHAgentEndpoint: %v", err)
}
if resolved.Endpoint != "/tmp/identity-agent.sock" {
t.Fatalf("endpoint=%q", resolved.Endpoint)
}
if resolved.Source != "identity-agent" {
t.Fatalf("source=%q", resolved.Source)
}
}
func TestResolveSSHAgentEndpointUsesSSHAuthSock(t *testing.T) {
t.Setenv("SSH_AUTH_SOCK", "/tmp/env-agent.sock")
resolved, err := resolveSSHAgentEndpoint(sshAgentDialOptions{})
if err != nil {
t.Fatalf("resolveSSHAgentEndpoint: %v", err)
}
if resolved.Endpoint != "/tmp/env-agent.sock" {
t.Fatalf("endpoint=%q", resolved.Endpoint)
}
if resolved.Source != "SSH_AUTH_SOCK" {
t.Fatalf("source=%q", resolved.Source)
}
}
func TestBuildSSHAgentAuthMethodTimesOutWhenAgentDoesNotRespond(t *testing.T) {
server, client := net.Pipe()
defer server.Close()
oldDialResolvedSSHAgent := dialResolvedSSHAgentFunc
t.Cleanup(func() {
dialResolvedSSHAgentFunc = oldDialResolvedSSHAgent
})
dialResolvedSSHAgentFunc = func(resolved resolvedSSHAgentEndpoint, timeout time.Duration) (net.Conn, error) {
return client, nil
}
_, cleanup, err := buildSSHAgentAuthMethod(sshAgentTimeouts{
Operation: 20 * time.Millisecond,
Endpoint: "/tmp/hung-agent.sock",
})
if cleanup != nil {
cleanup()
}
if !errors.Is(err, ErrSSHAgentTimeout) {
t.Fatalf("err=%v want ErrSSHAgentTimeout", err)
}
}
func TestBuildSSHAgentAuthMethodEmitsDebugEvents(t *testing.T) {
socketPath := tempUnixSocketPath(t)
listener, err := net.Listen("unix", socketPath)
if err != nil {
t.Fatalf("listen unix: %v", err)
}
defer listener.Close()
done := make(chan struct{})
go func() {
defer close(done)
conn, err := listener.Accept()
if err != nil {
return
}
_ = conn.Close()
}()
var events []SSHAgentDebugEvent
_, _, _ = buildSSHAgentAuthMethod(sshAgentTimeouts{
Dial: time.Second,
Operation: time.Second,
Endpoint: socketPath,
Debug: func(event SSHAgentDebugEvent) {
events = append(events, event)
},
})
<-done
if len(events) == 0 {
t.Fatal("expected debug events")
}
if events[0].Step != "auth" || events[0].Phase != "dial" {
t.Fatalf("unexpected first event: %+v", events[0])
}
if events[0].Endpoint != socketPath || events[0].Source != "identity-agent" {
t.Fatalf("unexpected endpoint event: %+v", events[0])
}
}
func tempUnixSocketPath(t *testing.T) string {
t.Helper()
path := t.TempDir() + "/agent.sock"
t.Cleanup(func() {
_ = os.Remove(path)
})
return path
}
+466
View File
@@ -0,0 +1,466 @@
package starssh
import (
"context"
"errors"
"sync"
"time"
)
const defaultExecPoolMaxOpenConns = 4
var ErrExecPoolClosed = errors.New("exec pool is closed")
type ExecPoolConfig struct {
Login LoginInput
MaxOpenConns int
MaxIdleConns int
MaxIdleTime time.Duration
DisableHealthCheck bool
HealthCheckTimeout time.Duration
}
type ExecPoolStats struct {
MaxOpenConns int
MaxIdleConns int
MaxIdleTime time.Duration
OpenConns int
IdleConns int
InUseConns int
}
type ExecPool struct {
loginInfo LoginInput
maxOpen int
maxIdle int
maxIdleTime time.Duration
idle chan *pooledClient
done chan struct{}
closeOnce sync.Once
healthCheckOnAcquire bool
healthCheckTimeout time.Duration
mu sync.Mutex
open int
closed bool
}
type pooledClient struct {
client *StarSSH
idleAt time.Time
}
func NewExecPool(config ExecPoolConfig) *ExecPool {
maxOpen := config.MaxOpenConns
if maxOpen <= 0 {
maxOpen = defaultExecPoolMaxOpenConns
}
maxIdle := config.MaxIdleConns
if maxIdle <= 0 || maxIdle > maxOpen {
maxIdle = maxOpen
}
return &ExecPool{
loginInfo: config.Login,
maxOpen: maxOpen,
maxIdle: maxIdle,
maxIdleTime: normalizeMaxIdleTime(config.MaxIdleTime),
idle: make(chan *pooledClient, maxIdle),
done: make(chan struct{}),
healthCheckOnAcquire: !config.DisableHealthCheck,
healthCheckTimeout: normalizeHealthCheckTimeout(config.HealthCheckTimeout),
}
}
func (p *ExecPool) Exec(ctx context.Context, req ExecRequest) (*ExecResult, error) {
client, err := p.Acquire(ctx)
if err != nil {
return nil, err
}
result, execErr := client.Exec(ctx, req)
if execErr != nil {
p.Discard(client)
return result, execErr
}
if releaseErr := p.Release(client); releaseErr != nil {
return result, releaseErr
}
return result, nil
}
func (p *ExecPool) ExecString(ctx context.Context, command string) (*ExecResult, error) {
return p.Exec(ctx, ExecRequest{
Command: command,
})
}
func (p *ExecPool) ExecStream(ctx context.Context, req ExecRequest, onChunk func(ExecStreamChunk)) (*ExecResult, error) {
client, err := p.Acquire(ctx)
if err != nil {
return nil, err
}
result, execErr := client.ExecStream(ctx, req, onChunk)
if execErr != nil {
p.Discard(client)
return result, execErr
}
if releaseErr := p.Release(client); releaseErr != nil {
return result, releaseErr
}
return result, nil
}
func (p *ExecPool) WarmUp(ctx context.Context, targetIdle int) error {
if p == nil {
return errors.New("exec pool is nil")
}
if ctx == nil {
ctx = context.Background()
}
targetIdle = p.normalizeWarmUpTarget(targetIdle)
if targetIdle == 0 {
return nil
}
for {
if err := ctx.Err(); err != nil {
return err
}
idleCount, create, err := p.tryWarmUp(targetIdle)
if err != nil {
return err
}
if idleCount >= targetIdle || !create {
return nil
}
conn, err := LoginContext(ctx, p.loginInfo)
if err != nil {
p.releaseSlot()
return err
}
if err := p.Release(conn); err != nil {
return err
}
}
}
func (p *ExecPool) Acquire(ctx context.Context) (*StarSSH, error) {
if p == nil {
return nil, errors.New("exec pool is nil")
}
if ctx == nil {
ctx = context.Background()
}
for {
idleClient, create, err := p.tryAcquire()
if err != nil {
return nil, err
}
if idleClient != nil {
client, ok := p.takeIdleClient(ctx, idleClient)
if !ok {
continue
}
return client, nil
}
if create {
conn, err := LoginContext(ctx, p.loginInfo)
if err != nil {
p.releaseSlot()
return nil, err
}
return conn, nil
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-p.done:
return nil, ErrExecPoolClosed
case idleClient = <-p.idle:
if idleClient == nil {
continue
}
client, ok := p.takeIdleClient(ctx, idleClient)
if !ok {
continue
}
return client, nil
}
}
}
func (p *ExecPool) Release(client *StarSSH) error {
if p == nil {
return errors.New("exec pool is nil")
}
if client == nil {
p.releaseSlot()
return nil
}
p.mu.Lock()
if p.closed {
p.mu.Unlock()
p.closeClient(client)
return nil
}
select {
case p.idle <- &pooledClient{
client: client,
idleAt: time.Now(),
}:
p.mu.Unlock()
return nil
default:
p.mu.Unlock()
p.closeClient(client)
return nil
}
}
func (p *ExecPool) Discard(client *StarSSH) {
if p == nil {
return
}
if client == nil {
p.releaseSlot()
return
}
p.closeClient(client)
}
func (p *ExecPool) Stats() ExecPoolStats {
if p == nil {
return ExecPoolStats{}
}
p.mu.Lock()
defer p.mu.Unlock()
idleCount := len(p.idle)
openCount := p.open
inUseCount := openCount - idleCount
if inUseCount < 0 {
inUseCount = 0
}
return ExecPoolStats{
MaxOpenConns: p.maxOpen,
MaxIdleConns: p.maxIdle,
MaxIdleTime: p.maxIdleTime,
OpenConns: openCount,
IdleConns: idleCount,
InUseConns: inUseCount,
}
}
func (p *ExecPool) Close() error {
if p == nil {
return nil
}
var closeErr error
p.closeOnce.Do(func() {
p.mu.Lock()
p.closed = true
idleClients := p.drainIdleLocked()
p.mu.Unlock()
close(p.done)
for _, client := range idleClients {
if err := client.Close(); err != nil && closeErr == nil {
closeErr = err
}
}
})
return closeErr
}
func (p *ExecPool) CloseIdle() error {
if p == nil {
return nil
}
p.mu.Lock()
idleClients := p.drainIdleLocked()
p.mu.Unlock()
var closeErr error
for _, client := range idleClients {
if err := client.Close(); err != nil && closeErr == nil {
closeErr = err
}
}
return closeErr
}
func (p *ExecPool) tryAcquire() (*pooledClient, bool, error) {
p.mu.Lock()
defer p.mu.Unlock()
if p.closed {
return nil, false, ErrExecPoolClosed
}
select {
case client := <-p.idle:
return client, false, nil
default:
}
if p.open < p.maxOpen {
p.open++
return nil, true, nil
}
return nil, false, nil
}
func (p *ExecPool) tryWarmUp(targetIdle int) (int, bool, error) {
p.mu.Lock()
defer p.mu.Unlock()
if p.closed {
return 0, false, ErrExecPoolClosed
}
idleCount := len(p.idle)
if idleCount >= targetIdle {
return idleCount, false, nil
}
if p.open >= p.maxOpen {
return idleCount, false, nil
}
p.open++
return idleCount, true, nil
}
func (p *ExecPool) normalizeWarmUpTarget(targetIdle int) int {
if p == nil || p.maxIdle <= 0 {
return 0
}
if targetIdle <= 0 {
return p.maxIdle
}
if targetIdle > p.maxIdle {
return p.maxIdle
}
return targetIdle
}
func (p *ExecPool) takeIdleClient(ctx context.Context, idleClient *pooledClient) (*StarSSH, bool) {
if idleClient == nil {
return nil, false
}
if idleClient.client == nil {
p.releaseSlot()
return nil, false
}
if p.isIdleExpired(idleClient) {
p.closePooledClient(idleClient)
return nil, false
}
if err := p.healthCheckClient(ctx, idleClient.client); err != nil {
p.closePooledClient(idleClient)
return nil, false
}
return idleClient.client, true
}
func (p *ExecPool) isIdleExpired(client *pooledClient) bool {
if p == nil || client == nil || client.client == nil {
return false
}
if p.maxIdleTime <= 0 || client.idleAt.IsZero() {
return false
}
return time.Since(client.idleAt) >= p.maxIdleTime
}
func (p *ExecPool) drainIdleLocked() []*StarSSH {
clients := make([]*StarSSH, 0, len(p.idle))
for {
select {
case idleClient := <-p.idle:
if p.open > 0 {
p.open--
}
if idleClient == nil || idleClient.client == nil {
continue
}
clients = append(clients, idleClient.client)
default:
return clients
}
}
}
func (p *ExecPool) releaseSlot() {
p.mu.Lock()
defer p.mu.Unlock()
if p.open > 0 {
p.open--
}
}
func (p *ExecPool) closeClient(client *StarSSH) {
if client != nil {
_ = client.Close()
}
p.releaseSlot()
}
func (p *ExecPool) closePooledClient(client *pooledClient) {
if client == nil {
return
}
if client.client == nil {
p.releaseSlot()
return
}
p.closeClient(client.client)
}
func (p *ExecPool) healthCheckClient(ctx context.Context, client *StarSSH) error {
if client == nil {
return errors.New("ssh client is nil")
}
if !p.healthCheckOnAcquire {
return nil
}
if ctx == nil {
ctx = context.Background()
}
timeout := p.healthCheckTimeout
if timeout > 0 {
healthCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
ctx = healthCtx
}
return client.PingContext(ctx)
}
func normalizeHealthCheckTimeout(timeout time.Duration) time.Duration {
if timeout <= 0 {
return defaultKeepAliveTimeout
}
return timeout
}
func normalizeMaxIdleTime(timeout time.Duration) time.Duration {
if timeout <= 0 {
return 0
}
return timeout
}
+145
View File
@@ -0,0 +1,145 @@
package starssh
import (
"errors"
"io"
"net"
"strings"
"golang.org/x/crypto/ssh"
)
var newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
return client.NewSession()
}
var requestSessionPTY = func(session *ssh.Session, config TerminalConfig) error {
return session.RequestPty(config.Term, config.Rows, config.Columns, config.Modes)
}
func (s *StarSSH) Close() error {
return s.closeTransport(true)
}
func (s *StarSSH) NewSession() (*ssh.Session, error) {
return s.NewPTYSession(nil)
}
func (s *StarSSH) NewExecSession() (*ssh.Session, error) {
client, err := s.requireSSHClient()
if err != nil {
return nil, err
}
session, err := NewExecSession(client)
if err != nil {
return nil, err
}
if err := s.maybeRequestAgentForwarding(session); err != nil {
_ = session.Close()
return nil, err
}
return session, nil
}
func (s *StarSSH) NewPTYSession(config *TerminalConfig) (*ssh.Session, error) {
client, err := s.requireSSHClient()
if err != nil {
return nil, err
}
session, err := NewPTYSession(client, config)
if err != nil {
return nil, err
}
if err := s.maybeRequestAgentForwarding(session); err != nil {
_ = session.Close()
return nil, err
}
return session, nil
}
func NewTransferSession(client *ssh.Client) (*ssh.Session, error) {
return NewExecSession(client)
}
func NewExecSession(client *ssh.Client) (*ssh.Session, error) {
if client == nil {
return nil, errors.New("ssh client is nil")
}
return newSSHSession(client)
}
func NewSession(client *ssh.Client) (*ssh.Session, error) {
return NewPTYSession(client, nil)
}
func NewPTYSession(client *ssh.Client, config *TerminalConfig) (*ssh.Session, error) {
if client == nil {
return nil, errors.New("ssh client is nil")
}
session, err := newSSHSession(client)
if err != nil {
return nil, err
}
cfg := normalizeTerminalConfig(config)
if err := requestSessionPTY(session, cfg); err != nil {
_ = session.Close()
return nil, err
}
return session, nil
}
func normalizeTerminalConfig(config *TerminalConfig) TerminalConfig {
cfg := TerminalConfig{
Term: defaultPTYTerm,
Rows: defaultPTYRows,
Columns: defaultPTYColumns,
Modes: ssh.TerminalModes{
ssh.ECHO: 1,
ssh.TTY_OP_ISPEED: 14400,
ssh.TTY_OP_OSPEED: 14400,
},
}
if config == nil {
return cfg
}
if strings.TrimSpace(config.Term) != "" {
cfg.Term = config.Term
}
if config.Rows > 0 {
cfg.Rows = config.Rows
}
if config.Columns > 0 {
cfg.Columns = config.Columns
}
if len(config.Modes) > 0 {
cfg.Modes = config.Modes
}
return cfg
}
func normalizeAlreadyClosedError(err error) error {
if err == nil {
return nil
}
if errors.Is(err, net.ErrClosed) {
return nil
}
if errors.Is(err, io.EOF) {
return nil
}
if strings.Contains(err.Error(), "use of closed network connection") {
return nil
}
return err
}
func closeUpstream(upstream *StarSSH) error {
if upstream == nil {
return nil
}
return upstream.Close()
}
+1795 -114
View File
File diff suppressed because it is too large Load Diff
+1127
View File
File diff suppressed because it is too large Load Diff
+592
View File
@@ -0,0 +1,592 @@
package starssh
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"strings"
"time"
)
var errStarShellPOSIXOnly = errors.New("legacy StarShell only supports POSIX-compatible shells")
type ShellRequest struct {
Command string
Timeout time.Duration
Keyword string
UseWaitDefault *bool
}
// NewShell creates the legacy prompt-driven POSIX shell helper.
// For raw interactive terminal flows, prefer NewTerminal.
func (s *StarSSH) NewShell() (shell *StarShell, err error) {
shell = &StarShell{
UseWaitDefault: true,
WaitTimeout: defaultShellWaitTimeout,
isecho: true,
iscolor: true,
promptToken: defaultShellPromptToken,
}
shell.Session, err = s.NewPTYSession(nil)
if err != nil {
return nil, err
}
shell.in, err = shell.Session.StdinPipe()
if err != nil {
_ = shell.Session.Close()
return nil, err
}
stdout, err := shell.Session.StdoutPipe()
if err != nil {
_ = shell.Session.Close()
return nil, err
}
shell.out = bufio.NewReader(stdout)
stderr, err := shell.Session.StderrPipe()
if err != nil {
_ = shell.Session.Close()
return nil, err
}
shell.er = bufio.NewReader(stderr)
if err := shell.Session.Shell(); err != nil {
_ = shell.Session.Close()
return nil, err
}
go shell.watchSession()
shell.gohub()
if err := shell.configurePrompt(context.Background()); err != nil {
_ = shell.Session.Close()
return nil, err
}
shell.Clear()
return shell, nil
}
func (s *StarShell) configurePrompt(ctx context.Context) error {
if s == nil {
return errors.New("shell is nil")
}
if ctx == nil {
ctx = context.Background()
}
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
timeoutCtx, cancel := context.WithTimeout(ctx, defaultShellSetupTimeout)
defer cancel()
ctx = timeoutCtx
}
prompt := s.promptToken + " "
setupCommands := []string{
"unset PROMPT_COMMAND >/dev/null 2>&1 || true",
fmt.Sprintf("export PS1=%s PS2='' PROMPT=%s RPROMPT='' >/dev/null 2>&1 || true", shellSingleQuote(prompt), shellSingleQuote(prompt)),
}
s.Clear()
for _, cmd := range setupCommands {
if err := s.WriteCommand(cmd); err != nil {
return fmt.Errorf("%w: %v", errStarShellPOSIXOnly, err)
}
}
probeToken := "__STARSSH_POSIX_READY__" + newNonce(6)
if err := s.WriteCommand(fmt.Sprintf("printf '%%s\\n' %s", shellSingleQuote(probeToken))); err != nil {
return fmt.Errorf("%w: %v", errStarShellPOSIXOnly, err)
}
ticker := time.NewTicker(defaultShellPollInterval)
defer ticker.Stop()
for {
outRaw, errRaw, runErr := s.readState()
if runErr != nil {
return fmt.Errorf("%w: %v", errStarShellPOSIXOnly, runErr)
}
outs := normalizeShellOutput(stripControlSequences(string(outRaw)))
if strings.Contains(outs, probeToken) {
s.Clear()
return nil
}
errs := normalizeShellOutput(stripControlSequences(string(errRaw)))
if looksLikeNonPOSIXShellError(errs) {
return fmt.Errorf("%w: %s", errStarShellPOSIXOnly, errs)
}
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
return fmt.Errorf("%w: prompt bootstrap timed out", errStarShellPOSIXOnly)
}
return ctx.Err()
case <-ticker.C:
}
}
}
func (s *StarShell) watchSession() {
if s == nil || s.Session == nil {
return
}
if err := s.Session.Wait(); err != nil && !errors.Is(err, io.EOF) {
s.setError(err)
}
}
func (s *StarShell) Close() error {
if s == nil {
return nil
}
var closeErr error
s.closeOnce.Do(func() {
if s.Session == nil {
return
}
closeErr = s.Session.Close()
})
return closeErr
}
func (s *StarShell) SwitchNoColor(run bool) {
s.rw.Lock()
defer s.rw.Unlock()
s.iscolor = run
}
func (s *StarShell) SwitchEcho(run bool) {
s.rw.Lock()
defer s.rw.Unlock()
s.isecho = run
}
func (s *StarShell) TrimColor(str string) string {
s.rw.RLock()
shouldTrim := s.iscolor
s.rw.RUnlock()
if shouldTrim {
return SedColor(str)
}
return str
}
/*
本函数控制是否在本地屏幕上打印远程Shell的输出内容[true|false]
*/
func (s *StarShell) SwitchPrint(run bool) {
s.rw.Lock()
defer s.rw.Unlock()
s.isprint = run
}
/*
本函数控制是否立即处理远程Shell输出每一行内容[true|false]
*/
func (s *StarShell) SwitchFunc(run bool) {
s.rw.Lock()
defer s.rw.Unlock()
s.isfuncs = run
}
func (s *StarShell) SetFunc(funcs func(string)) {
s.rw.Lock()
defer s.rw.Unlock()
s.funcs = funcs
}
func (s *StarShell) Clear() {
s.rw.Lock()
defer s.rw.Unlock()
s.outbyte = []byte{}
s.errbyte = []byte{}
}
func (s *StarShell) ShellClear(cmd string, sleep int) (string, string, error) {
s.Clear()
defer s.Clear()
return s.Shell(cmd, sleep)
}
func (s *StarShell) Shell(cmd string, sleep int) (string, string, error) {
s.commandMu.Lock()
defer s.commandMu.Unlock()
if err := s.WriteCommand(cmd); err != nil {
return "", "", err
}
outRaw, errRaw, runErr := s.GetResult(sleep)
if runErr != nil {
return "", "", runErr
}
outText := s.TrimColor(strings.TrimSpace(string(outRaw)))
s.rw.RLock()
echoEnabled := s.isecho
s.rw.RUnlock()
if echoEnabled {
outText = stripCommandEchoFromOutput(outText, cmd)
}
return strings.TrimSpace(outText), s.TrimColor(strings.TrimSpace(string(errRaw))), nil
}
func (s *StarShell) ShellWait(cmd string) (string, string, error) {
result, err := s.Run(context.Background(), ShellRequest{
Command: cmd,
})
if err != nil {
return "", "", err
}
return strings.TrimSpace(result.StdoutString()), strings.TrimSpace(result.StderrString()), result.CommandError()
}
func (s *StarShell) RunString(ctx context.Context, command string) (*ExecResult, error) {
return s.Run(ctx, ShellRequest{
Command: command,
})
}
func (s *StarShell) Run(ctx context.Context, req ShellRequest) (*ExecResult, error) {
if s == nil {
return nil, errors.New("shell is nil")
}
if ctx == nil {
ctx = context.Background()
}
s.commandMu.Lock()
defer s.commandMu.Unlock()
if strings.TrimSpace(req.Command) == "" {
return nil, errors.New("command is empty")
}
s.rw.RLock()
useDefault := s.UseWaitDefault
keyword := s.Keyword
waitTimeout := s.WaitTimeout
promptToken := s.promptToken
s.rw.RUnlock()
if req.UseWaitDefault != nil {
useDefault = *req.UseWaitDefault
}
if strings.TrimSpace(req.Keyword) != "" {
keyword = req.Keyword
}
if req.Timeout > 0 {
waitTimeout = req.Timeout
}
if !useDefault && keyword == "" {
return nil, errors.New("ShellRun requires UseWaitDefault=true or Keyword set")
}
if waitTimeout <= 0 {
waitTimeout = defaultShellWaitTimeout
}
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
timeoutCtx, cancel := context.WithTimeout(ctx, waitTimeout)
defer cancel()
ctx = timeoutCtx
}
startAt := time.Now()
s.Clear()
defer s.Clear()
beginToken, endToken := newCommandTokens()
markerCmd := fmt.Sprintf("__STARSSH_RC=$?; printf '%s:%%s\\n' \"$__STARSSH_RC\"", endToken)
if err := s.WriteCommand(fmt.Sprintf("printf '%s\\n'", beginToken)); err != nil {
return nil, err
}
if err := s.WriteCommand(req.Command); err != nil {
return nil, err
}
if err := s.WriteCommand(markerCmd); err != nil {
return nil, err
}
var (
outc string
errc string
exitCode int
done bool
)
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(defaultShellPollInterval):
}
outRaw, errRaw, runErr := s.readState()
if runErr != nil {
return nil, runErr
}
outs := normalizeShellOutput(stripControlSequences(string(outRaw)))
errs := normalizeShellOutput(stripControlSequences(string(errRaw)))
s.rw.RLock()
useDefault = s.UseWaitDefault
keyword = s.Keyword
s.rw.RUnlock()
if useDefault {
segment, rc, found, parseErr := extractCommandSegment(outs, beginToken, endToken)
if parseErr != nil {
return nil, parseErr
}
if found {
outc = segment
errc = errs
exitCode = rc
done = true
break
}
}
if keyword != "" {
if strings.Contains(outs, keyword) || strings.Contains(errs, keyword) {
outc = outs
errc = errs
done = true
break
}
}
}
if !done {
return nil, errors.New("failed to collect shell result")
}
outc = collectLinesForCommandOutput(outc, promptToken, beginToken, endToken)
errc = collectLinesForCommandOutput(errc, promptToken, beginToken, endToken)
outc = stripCommandEchoFromOutput(outc, req.Command)
stdoutText := strings.TrimSpace(outc)
stderrText := strings.TrimSpace(errc)
result := &ExecResult{
Command: req.Command,
Stdout: []byte(stdoutText),
Stderr: []byte(stderrText),
Combined: combineCommandOutput(stdoutText, stderrText),
Duration: time.Since(startAt),
}
if useDefault {
result.ExitCode = exitCode
}
return result, nil
}
func extractCommandSegment(stdout string, beginToken string, endToken string) (string, int, bool, error) {
lines := strings.Split(stdout, "\n")
beginLine := -1
for i, line := range lines {
if strings.TrimSpace(line) == beginToken {
beginLine = i
break
}
}
if beginLine < 0 {
return "", 0, false, nil
}
segment := strings.Join(lines[beginLine+1:], "\n")
before, rc, found, err := splitByEndToken(segment, endToken)
if err != nil {
return "", 0, false, err
}
if !found {
return "", 0, false, nil
}
return strings.TrimSpace(before), rc, true, nil
}
func (s *StarShell) GetResult(sleep int) ([]byte, []byte, error) {
if sleep > 0 {
time.Sleep(time.Millisecond * time.Duration(sleep))
}
return s.readState()
}
func (s *StarShell) WriteCommand(cmd string) error {
return s.Write([]byte(cmd + "\n"))
}
func (s *StarShell) Write(bstr []byte) error {
if s == nil {
return errors.New("shell is nil")
}
if s.in == nil {
return errors.New("shell stdin is not initialized")
}
if _, _, runErr := s.readState(); runErr != nil {
return runErr
}
s.writeMu.Lock()
defer s.writeMu.Unlock()
_, err := s.in.Write(bstr)
return err
}
func (s *StarShell) gohub() {
if s.er != nil {
go s.streamPump(s.er, true)
}
if s.out != nil {
go s.streamPump(s.out, false)
}
}
func (s *StarShell) streamPump(reader *bufio.Reader, isStderr bool) {
var cache []byte
for {
read, err := reader.ReadByte()
if err == io.EOF {
return
}
if err != nil {
s.setError(err)
return
}
s.rw.Lock()
if isStderr {
s.errbyte = append(s.errbyte, read)
} else {
s.outbyte = append(s.outbyte, read)
}
printEnabled := s.isprint
funcEnabled := s.isfuncs && s.funcs != nil
lineHandler := s.funcs
trimColor := s.iscolor
s.rw.Unlock()
if printEnabled {
fmt.Print(string([]byte{read}))
}
cache = append(cache, read)
if read == '\n' {
if funcEnabled {
line := strings.TrimSpace(string(cache))
if trimColor {
line = SedColor(line)
}
go lineHandler(line)
}
cache = cache[:0]
}
}
}
func (s *StarShell) setError(err error) {
if err == nil || errors.Is(err, io.EOF) {
return
}
s.rw.Lock()
defer s.rw.Unlock()
if s.errors == nil {
s.errors = err
}
}
func (s *StarShell) readState() ([]byte, []byte, error) {
s.rw.RLock()
defer s.rw.RUnlock()
outCopy := make([]byte, len(s.outbyte))
copy(outCopy, s.outbyte)
errCopy := make([]byte, len(s.errbyte))
copy(errCopy, s.errbyte)
if s.errors != nil {
return outCopy, errCopy, s.errors
}
return outCopy, errCopy, nil
}
func stripCommandEchoFromOutput(output string, cmd string) string {
if output == "" || cmd == "" {
return strings.TrimSpace(output)
}
lines := strings.Split(output, "\n")
cmdLines := strings.Split(cmd, "\n")
for _, cmdLine := range cmdLines {
trimmedCmd := strings.TrimSpace(cmdLine)
if trimmedCmd == "" {
continue
}
for i, line := range lines {
if strings.TrimSpace(line) == trimmedCmd {
lines = append(lines[:i], lines[i+1:]...)
break
}
}
}
return strings.TrimSpace(strings.Join(lines, "\n"))
}
func looksLikeNonPOSIXShellError(output string) bool {
if strings.TrimSpace(output) == "" {
return false
}
lower := strings.ToLower(output)
indicators := []string{
"is not recognized as an internal or external command",
"the term ",
"command not found",
"unknown command",
"not found",
"not recognized",
}
for _, indicator := range indicators {
if strings.Contains(lower, indicator) {
return true
}
}
return false
}
func (s *StarShell) GetUid() string {
res, _, _ := s.ShellWait("id -u")
return strings.TrimSpace(res)
}
func (s *StarShell) GetGid() string {
res, _, _ := s.ShellWait("id -g")
return strings.TrimSpace(res)
}
func (s *StarShell) GetUser() string {
res, _, _ := s.ShellWait("id -un")
return strings.TrimSpace(res)
}
func (s *StarShell) GetGroup() string {
res, _, _ := s.ShellWait("id -gn")
return strings.TrimSpace(res)
}
-636
View File
@@ -1,636 +0,0 @@
package starssh
import (
"bufio"
"encoding/base64"
"errors"
"fmt"
"golang.org/x/crypto/ssh"
"io"
"io/ioutil"
"net"
"regexp"
"strings"
"sync"
"time"
)
type StarSSH struct {
Client *ssh.Client
PublicKey ssh.PublicKey
PubkeyBase64 string
Hostname string
RemoteAddr net.Addr
Banner string
LoginInfo LoginInput
online bool
}
type LoginInput struct {
KeyExchanges []string
Ciphers []string
MACs []string
User string
Password string
Prikey string
PrikeyPwd string
Addr string
Port int
Timeout time.Duration
HostKeyCallback func(string, net.Addr, ssh.PublicKey) error
BannerCallback func(string) error
}
type StarShell struct {
Keyword string
UseWaitDefault bool
Session *ssh.Session
in io.Writer
out *bufio.Reader
er *bufio.Reader
outbyte []byte
errbyte []byte
lastout int64
errors error
isprint bool
isfuncs bool
iscolor bool
isecho bool
rw sync.RWMutex
funcs func(string)
}
func Login(info LoginInput) (*StarSSH, error) {
var (
auth []ssh.AuthMethod
clientConfig *ssh.ClientConfig
config ssh.Config
err error
)
sshInfo := new(StarSSH)
// get auth method
auth = make([]ssh.AuthMethod, 0)
if info.Prikey == "" {
keyboardInteractiveChallenge := func(
user,
instruction string,
questions []string,
echos []bool,
) (answers []string, err error) {
if len(questions) == 0 {
return []string{}, nil
}
return []string{info.Password}, nil
}
auth = append(auth, ssh.Password(info.Password))
auth = append(auth, ssh.KeyboardInteractive(keyboardInteractiveChallenge))
} else {
pemBytes := []byte(info.Prikey)
var signer ssh.Signer
if info.PrikeyPwd == "" {
signer, err = ssh.ParsePrivateKey(pemBytes)
} else {
signer, err = ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(info.PrikeyPwd))
}
if err != nil {
return nil, err
}
auth = append(auth, ssh.PublicKeys(signer))
}
if len(info.Ciphers) == 0 {
config = ssh.Config{
Ciphers: []string{"aes128-ctr", "aes192-ctr", "aes256-ctr", "aes128-gcm@openssh.com", "arcfour256", "arcfour128", "aes128-cbc", "3des-cbc", "aes192-cbc", "aes256-cbc", "chacha20-poly1305@openssh.com"},
}
} else {
config = ssh.Config{
Ciphers: info.Ciphers,
}
}
if len(info.MACs) != 0 {
config.MACs = info.MACs
}
if len(info.KeyExchanges) != 0 {
config.KeyExchanges = info.KeyExchanges
}
if info.Timeout == 0 {
info.Timeout = time.Second * 5
}
hostKeycbfunc := func(hostname string, remote net.Addr, key ssh.PublicKey) error {
sshInfo.PublicKey = key
sshInfo.RemoteAddr = remote
sshInfo.Hostname = hostname
if info.HostKeyCallback != nil {
return info.HostKeyCallback(hostname, remote, key)
}
return nil
}
bannercbfunc := func(banner string) error {
sshInfo.Banner = banner
if info.BannerCallback != nil {
return info.BannerCallback(banner)
}
return nil
}
clientConfig = &ssh.ClientConfig{
User: info.User,
Auth: auth,
Timeout: info.Timeout,
Config: config,
HostKeyCallback: hostKeycbfunc,
BannerCallback: bannercbfunc,
}
// connet to ssh
sshInfo.LoginInfo = info
sshInfo.Client, err = ssh.Dial("tcp", fmt.Sprintf("%s:%d", info.Addr, info.Port), clientConfig)
if err == nil && sshInfo.PublicKey != nil {
sshInfo.online = true
sshInfo.PubkeyBase64 = base64.StdEncoding.EncodeToString(sshInfo.PublicKey.Marshal())
}
return sshInfo, err
}
func LoginSimple(host string, user string, passwd string, prikeyPath string, port int, timeout time.Duration) (*StarSSH, error) {
var info = LoginInput{
Addr: host,
Port: port,
Timeout: timeout,
User: user,
}
if prikeyPath != "" {
prikey, err := ioutil.ReadFile(prikeyPath)
if err != nil {
return nil, err
}
info.Prikey = string(prikey)
if passwd != "" {
info.PrikeyPwd = passwd
}
} else {
info.Password = passwd
}
return Login(info)
}
func (s *StarShell) ShellWait(cmd string) (string, string, error) {
var outc, errc string = " ", " "
s.Clear()
defer s.Clear()
echo := "echo b7Y85R56TUY6R5UTb612"
err := s.WriteCommand(cmd)
if err != nil {
return "", "", err
}
time.Sleep(time.Millisecond * 20)
err = s.WriteCommand(echo)
if err != nil {
return "", "", err
}
for {
time.Sleep(time.Millisecond * 120)
outs := string(s.outbyte)
errs := string(s.errbyte)
outs = strings.TrimSpace(strings.ReplaceAll(outs, "\r\n", "\n"))
errs = strings.TrimSpace(strings.ReplaceAll(errs, "\r\n", "\n"))
if len(outs) >= len(cmd+"\n"+echo) && outs[0:len(cmd+"\n"+echo)] == cmd+"\n"+echo {
outs = outs[len(cmd+"\n"+echo):]
} else if len(outs) >= len(cmd) && outs[0:len(cmd)] == cmd {
outs = outs[len(cmd):]
}
if len(errs) >= len(cmd) && errs[0:len(cmd)] == cmd {
errs = errs[len(cmd):]
}
if s.UseWaitDefault {
if strings.Index(string(outs), "b7Y85R56TUY6R5UTb612") >= 0 {
list := strings.Split(string(outs), "\n")
for _, v := range list {
if strings.Index(v, "b7Y85R56TUY6R5UTb612") < 0 {
outc += v + "\n"
}
}
break
}
if strings.Index(string(errs), "b7Y85R56TUY6R5UTb612") >= 0 {
list := strings.Split(string(errs), "\n")
for _, v := range list {
if strings.Index(v, "b7Y85R56TUY6R5UTb612") < 0 {
errc += v + "\n"
}
}
break
}
}
if s.Keyword != "" {
if strings.Index(string(outs), s.Keyword) >= 0 {
list := strings.Split(string(outs), "\n")
for _, v := range list {
if strings.Index(v, s.Keyword) < 0 && strings.Index(v, "b7Y85R56TUY6R5UTb612") < 0 {
outc += v + "\n"
}
}
break
}
if strings.Index(string(errs), s.Keyword) >= 0 {
list := strings.Split(string(errs), "\n")
for _, v := range list {
if strings.Index(v, s.Keyword) < 0 && strings.Index(v, "b7Y85R56TUY6R5UTb612") < 0 {
errc += v + "\n"
}
}
break
}
}
}
return s.TrimColor(strings.TrimSpace(outc)), s.TrimColor(strings.TrimSpace(errc)), err
}
func (s *StarShell) Close() error {
return s.Session.Close()
}
func (s *StarShell) SwitchNoColor(is bool) {
s.iscolor = is
}
func (s *StarShell) SwitchEcho(is bool) {
s.isecho = is
}
func (s *StarShell) TrimColor(str string) string {
if s.iscolor {
return SedColor(str)
}
return str
}
/*
本函数控制是否在本地屏幕上打印远程Shell的输出内容[true|false]
*/
func (s *StarShell) SwitchPrint(run bool) {
s.isprint = run
}
/*
本函数控制是否立即处理远程Shell输出每一行内容[true|false]
*/
func (s *StarShell) SwitchFunc(run bool) {
s.isfuncs = run
}
func (s *StarShell) SetFunc(funcs func(string)) {
s.funcs = funcs
}
func (s *StarShell) Clear() {
defer s.rw.Unlock()
s.rw.Lock()
s.outbyte = []byte{}
s.errbyte = []byte{}
time.Sleep(time.Millisecond * 15)
}
func (s *StarShell) ShellClear(cmd string, sleep int) (string, string, error) {
defer s.Clear()
s.Clear()
return s.Shell(cmd, sleep)
}
func (s *StarShell) Shell(cmd string, sleep int) (string, string, error) {
if err := s.WriteCommand(cmd); err != nil {
return "", "", err
}
tmp1, tmp2, err := s.GetResult(sleep)
tmps := s.TrimColor(strings.TrimSpace(string(tmp1)))
if s.isecho {
n := len(strings.Split(cmd, "\n"))
if n == 1 {
list := strings.SplitN(tmps, "\n", 2)
if len(list) == 2 {
tmps = list[1]
}
} else {
list := strings.Split(tmps, "\n")
cmds := strings.Split(cmd, "\n")
for _, v := range cmds {
for k, v2 := range list {
if strings.TrimSpace(v2) == strings.TrimSpace(v) {
list[k] = ""
break
}
}
}
tmps = ""
for _, v := range list {
if v != "" {
tmps += v + "\n"
}
}
tmps = tmps[0 : len(tmps)-1]
}
}
return tmps, s.TrimColor(strings.TrimSpace(string(tmp2))), err
}
func (s *StarShell) GetResult(sleep int) ([]byte, []byte, error) {
if s.errors != nil {
s.Session.Close()
return s.outbyte, s.errbyte, s.errors
}
if sleep > 0 {
time.Sleep(time.Millisecond * time.Duration(sleep))
}
return s.outbyte, s.errbyte, nil
}
func (s *StarShell) WriteCommand(cmd string) error {
return s.Write([]byte(cmd + "\n"))
}
func (s *StarShell) Write(bstr []byte) error {
if s.errors != nil {
s.Session.Close()
return s.errors
}
_, err := s.in.Write(bstr)
return err
}
func (s *StarShell) gohub() {
go func() {
var cache []byte
for {
read, err := s.er.ReadByte()
if err != nil {
s.errors = err
return
}
s.errbyte = append(s.errbyte, read)
if s.isprint {
fmt.Print(string([]byte{read}))
}
cache = append(cache, read)
if read == '\n' {
if s.isfuncs {
go s.funcs(s.TrimColor(strings.TrimSpace(string(cache))))
cache = []byte{}
}
}
}
}()
var cache []byte
for {
read, err := s.out.ReadByte()
if err != nil {
s.errors = err
return
}
s.rw.Lock()
s.outbyte = append(s.outbyte, read)
cache = append(cache, read)
s.rw.Unlock()
if read == '\n' {
if s.isfuncs {
go s.funcs(strings.TrimSpace(string(cache)))
cache = []byte{}
}
}
if s.isprint {
fmt.Print(string([]byte{read}))
}
}
}
func (s *StarShell) GetUid() string {
res, _, _ := s.ShellWait(`id | grep -oP "(?<=uid\=)\d+"`)
return strings.TrimSpace(res)
}
func (s *StarShell) GetGid() string {
res, _, _ := s.ShellWait(`id | grep -oP "(?<=gid\=)\d+"`)
return strings.TrimSpace(res)
}
func (s *StarShell) GetUser() string {
res, _, _ := s.ShellWait(`id | grep -oP "(?<=\().*?(?=\))" | head -n 1`)
return strings.TrimSpace(res)
}
func (s *StarShell) GetGroup() string {
res, _, _ := s.ShellWait(`id | grep -oP "(?<=\().*?(?=\))" | head -n 2 | tail -n 1`)
return strings.TrimSpace(res)
}
func (s *StarSSH) NewShell() (shell *StarShell, err error) {
shell = new(StarShell)
shell.Session, err = s.NewSession()
if err != nil {
return
}
shell.in, _ = shell.Session.StdinPipe()
tmp, _ := shell.Session.StdoutPipe()
shell.out = bufio.NewReader(tmp)
tmp, _ = shell.Session.StderrPipe()
shell.er = bufio.NewReader(tmp)
err = shell.Session.Shell()
shell.isecho = true
go shell.Session.Wait()
shell.UseWaitDefault = true
shell.WriteCommand("bash")
time.Sleep(500 * time.Millisecond)
shell.WriteCommand("export PS1= ")
shell.WriteCommand("export PS2= ")
go shell.gohub()
time.Sleep(500 * time.Millisecond)
shell.Clear()
shell.Clear()
return
}
func (s *StarSSH) Close() error {
if s.online {
return s.Client.Close()
}
return nil
}
func (s *StarSSH) NewSession() (*ssh.Session, error) {
return NewSession(s.Client)
}
func (s *StarSSH) ShellOne(cmd string) (string, error) {
newsess, err := s.NewSession()
if err != nil {
return "", err
}
data, err := newsess.CombinedOutput(cmd)
newsess.Close()
return strings.TrimSpace(string(data)), err
}
func (s *StarSSH) Exists(filepath string) bool {
res, _ := s.ShellOne(`echo 1 && [ ! -e "` + filepath + `" ] && echo 2`)
if res == "1" {
return true
} else {
return false
}
}
func (s *StarSSH) GetUid() string {
res, _ := s.ShellOne(`id | grep -oP "(?<=uid\=)\d+"`)
return strings.TrimSpace(res)
}
func (s *StarSSH) GetGid() string {
res, _ := s.ShellOne(`id | grep -oP "(?<=gid\=)\d+"`)
return strings.TrimSpace(res)
}
func (s *StarSSH) GetUser() string {
res, _ := s.ShellOne(`id | grep -oP "(?<=\().*?(?=\))" | head -n 1`)
return strings.TrimSpace(res)
}
func (s *StarSSH) GetGroup() string {
res, _ := s.ShellOne(`id | grep -oP "(?<=\().*?(?=\))" | head -n 2 | tail -n 1`)
return strings.TrimSpace(res)
}
func (s *StarSSH) IsFile(filepath string) bool {
res, _ := s.ShellOne(`echo 1 && [ ! -f "` + filepath + `" ] && echo 2`)
if res == "1" {
return true
} else {
return false
}
}
func (s *StarSSH) IsFolder(filepath string) bool {
res, _ := s.ShellOne(`echo 1 && [ ! -d "` + filepath + `" ] && echo 2`)
if res == "1" {
return true
} else {
return false
}
}
func (s *StarSSH) ShellOneShowScreen(cmd string) (string, error) {
newsess, err := s.NewSession()
if err != nil {
return "", err
}
var bytes, errbytes []byte
tmp, _ := newsess.StdoutPipe()
reader := bufio.NewReader(tmp)
tmp, _ = newsess.StderrPipe()
errder := bufio.NewReader(tmp)
err = newsess.Start(cmd)
if err != nil {
return "", err
}
c := make(chan int, 1)
go newsess.Wait()
go func() {
for {
byt, err := reader.ReadByte()
if err != nil {
break
}
fmt.Print(string([]byte{byt}))
bytes = append(bytes, byt)
}
c <- 1
}()
for {
byt, err := errder.ReadByte()
if err != nil {
break
}
fmt.Print(string([]byte{byt}))
errbytes = append(errbytes, byt)
}
_ = <-c
newsess.Close()
if len(errbytes) != 0 {
err = errors.New(strings.TrimSpace(string(errbytes)))
} else {
err = nil
}
return strings.TrimSpace(string(bytes)), err
}
func (s *StarSSH) ShellOneToFunc(cmd string, callback func(string)) (string, error) {
newsess, err := s.NewSession()
if err != nil {
return "", err
}
var bytes, errbytes []byte
tmp, _ := newsess.StdoutPipe()
reader := bufio.NewReader(tmp)
tmp, _ = newsess.StderrPipe()
errder := bufio.NewReader(tmp)
err = newsess.Start(cmd)
if err != nil {
return "", err
}
c := make(chan int, 1)
go newsess.Wait()
go func() {
for {
byt, err := reader.ReadByte()
if err != nil {
break
}
callback(string([]byte{byt}))
bytes = append(bytes, byt)
}
c <- 1
}()
for {
byt, err := errder.ReadByte()
if err != nil {
break
}
callback(string([]byte{byt}))
errbytes = append(errbytes, byt)
}
_ = <-c
newsess.Close()
if len(errbytes) != 0 {
err = errors.New(strings.TrimSpace(string(errbytes)))
} else {
err = nil
}
return strings.TrimSpace(string(bytes)), err
}
func NewTransferSession(client *ssh.Client) (*ssh.Session, error) {
session, err := client.NewSession()
return session, err
}
func NewSession(client *ssh.Client) (*ssh.Session, error) {
var session *ssh.Session
var err error
// create session
if session, err = client.NewSession(); err != nil {
return nil, err
}
modes := ssh.TerminalModes{
ssh.ECHO: 1, // 还是要强制开启
//ssh.IGNCR: 0,
ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
}
if err := session.RequestPty("xterm", 500, 250, modes); err != nil {
return nil, err
}
return session, nil
}
func SedColor(str string) string {
reg := regexp.MustCompile(`\x1B\[([0-9]{1,2}(;[0-9]{1,2})?)?[m|K]`)
//fmt.Println("regexp:", reg.Match([]byte(str)))
return string(reg.ReplaceAll([]byte(str), []byte("")))
}
+668
View File
@@ -0,0 +1,668 @@
package starssh
import (
"errors"
"fmt"
"io"
"sort"
"strconv"
"strings"
"sync"
"time"
"golang.org/x/crypto/ssh"
sshagent "golang.org/x/crypto/ssh/agent"
)
var errSSHAgentUnavailable = errors.New("ssh-agent unavailable")
var errRetrySSHAgentAuth = errors.New("retry ssh-agent auth")
var buildSSHAgentAuthMethodFunc = buildSSHAgentAuthMethod
type sshAgentTimeouts struct {
Dial time.Duration
Operation time.Duration
Forward time.Duration
Endpoint string
Resolved resolvedSSHAgentEndpoint
Debug SSHAgentDebugFunc
SkipFingerprints map[string]struct{}
SignFailure func(ssh.PublicKey, error)
}
type sshAgentAuthAttempt struct {
mu sync.Mutex
skipFingerprints map[string]struct{}
retryRequested bool
}
var defaultAuthOrder = []AuthMethodKind{
AuthMethodSSHAgent,
AuthMethodPrivateKey,
AuthMethodPassword,
AuthMethodKeyboardInteractive,
}
func effectiveSSHAgentTimeout(info LoginInput) time.Duration {
switch {
case info.SSHAgentTimeout < 0:
return 0
case info.SSHAgentTimeout > 0:
return info.SSHAgentTimeout
default:
return defaultSSHAgentTimeout
}
}
func effectiveSSHAgentTimeouts(info LoginInput) sshAgentTimeouts {
return sshAgentTimeouts{
Dial: effectiveDialTimeout(info),
Operation: effectiveSSHAgentTimeout(info),
Forward: effectiveSSHAgentForwardTimeout(info),
Endpoint: info.IdentityAgent,
Debug: info.SSHAgentDebug,
}
}
func effectiveSSHAgentForwardTimeout(info LoginInput) time.Duration {
if info.SSHAgentForwardTimeout > 0 {
return info.SSHAgentForwardTimeout
}
return 0
}
func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) {
return buildAuthMethodsWithAgentAttempt(info, nil)
}
func buildAuthMethodsWithAgentAttempt(info LoginInput, agentAttempt *sshAgentAuthAttempt) ([]ssh.AuthMethod, func(), error) {
order, err := normalizeAuthOrder(info.AuthOrder)
if err != nil {
return nil, nil, err
}
auth := make([]ssh.AuthMethod, 0, len(order))
var agentErr error
var cleanupFuncs []func()
for _, methodKind := range order {
switch methodKind {
case AuthMethodPrivateKey:
method, err := buildPrivateKeyAuthMethod(info, agentAttempt)
if err != nil {
return nil, nil, err
}
if method != nil {
auth = append(auth, method)
}
case AuthMethodPassword:
method := buildPasswordAuthMethod(info.Password, info.PasswordCallback, agentAttempt)
if method != nil {
auth = append(auth, method)
}
case AuthMethodKeyboardInteractive:
method := buildKeyboardInteractiveAuthMethod(info.Password, info.PasswordCallback, info.KeyboardInteractiveCallback, agentAttempt)
if method != nil {
auth = append(auth, method)
}
case AuthMethodSSHAgent:
if info.DisableSSHAgent {
continue
}
timeouts := effectiveSSHAgentTimeouts(info)
if agentAttempt != nil {
timeouts.SkipFingerprints = agentAttempt.skipSnapshot()
timeouts.SignFailure = agentAttempt.recordSignFailure
}
agentMethod, cleanup, err := buildSSHAgentAuthMethodFunc(timeouts)
if err != nil {
agentErr = err
continue
}
if agentMethod != nil {
auth = append(auth, agentMethod)
}
if cleanup != nil {
cleanupFuncs = append(cleanupFuncs, cleanup)
}
}
}
if len(auth) == 0 {
if agentErr != nil {
return nil, nil, fmt.Errorf("no authentication method provided; ssh-agent unavailable: %w", agentErr)
}
return nil, nil, errors.New("no authentication method provided: password, private key, or ssh-agent is required")
}
return auth, composeCleanup(cleanupFuncs...), nil
}
func normalizeAuthOrder(order []AuthMethodKind) ([]AuthMethodKind, error) {
if len(order) == 0 {
return append([]AuthMethodKind(nil), defaultAuthOrder...), nil
}
normalized := make([]AuthMethodKind, 0, len(order))
seen := make(map[AuthMethodKind]struct{}, len(order))
for _, raw := range order {
kind := AuthMethodKind(strings.ToLower(strings.TrimSpace(string(raw))))
if kind == "" {
return nil, errors.New("auth order contains an empty auth method")
}
if !isSupportedAuthMethodKind(kind) {
return nil, fmt.Errorf("unsupported auth method %q", raw)
}
if _, exists := seen[kind]; exists {
continue
}
seen[kind] = struct{}{}
normalized = append(normalized, kind)
}
if len(normalized) == 0 {
return nil, errors.New("auth order is empty")
}
return normalized, nil
}
func isSupportedAuthMethodKind(kind AuthMethodKind) bool {
switch kind {
case AuthMethodPrivateKey, AuthMethodPassword, AuthMethodKeyboardInteractive, AuthMethodSSHAgent:
return true
default:
return false
}
}
func shouldRetrySSHAgentAuth(info LoginInput, order []AuthMethodKind) bool {
if info.DisableSSHAgent {
return false
}
for _, methodKind := range order {
if methodKind == AuthMethodSSHAgent {
return true
}
}
return false
}
func buildPrivateKeyAuthMethod(info LoginInput, agentAttempt *sshAgentAuthAttempt) (ssh.AuthMethod, error) {
if strings.TrimSpace(info.Prikey) == "" {
return nil, nil
}
pemBytes := []byte(info.Prikey)
if info.PrikeyPwd == "" {
signer, err := ssh.ParsePrivateKey(pemBytes)
if err != nil {
return nil, err
}
return ssh.PublicKeysCallback(privateKeySignersCallback(signer, agentAttempt)), nil
}
signer, err := ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(info.PrikeyPwd))
if err != nil {
return nil, err
}
return ssh.PublicKeysCallback(privateKeySignersCallback(signer, agentAttempt)), nil
}
func privateKeySignersCallback(signer ssh.Signer, agentAttempt *sshAgentAuthAttempt) func() ([]ssh.Signer, error) {
return func() ([]ssh.Signer, error) {
if err := checkSSHAgentRetryPending(agentAttempt); err != nil {
return nil, err
}
return []ssh.Signer{signer}, nil
}
}
func buildPasswordAuthMethod(password string, callback func() (string, error), agentAttempt *sshAgentAuthAttempt) ssh.AuthMethod {
if password == "" && callback == nil {
return nil
}
return ssh.PasswordCallback(func() (string, error) {
if err := checkSSHAgentRetryPending(agentAttempt); err != nil {
return "", err
}
if password != "" {
return password, nil
}
return callback()
})
}
func buildKeyboardInteractiveAuthMethod(
password string,
passwordCallback func() (string, error),
challenge ssh.KeyboardInteractiveChallenge,
agentAttempt *sshAgentAuthAttempt,
) ssh.AuthMethod {
if challenge != nil {
return ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) ([]string, error) {
if err := checkSSHAgentRetryPending(agentAttempt); err != nil {
return nil, err
}
return challenge(user, instruction, questions, echos)
})
}
if password == "" && passwordCallback == nil {
return nil
}
keyboardInteractiveChallenge := func(user, instruction string, questions []string, echos []bool) ([]string, error) {
if err := checkSSHAgentRetryPending(agentAttempt); err != nil {
return nil, err
}
if len(questions) == 0 {
return []string{}, nil
}
answer := password
if answer == "" {
var err error
answer, err = passwordCallback()
if err != nil {
return nil, err
}
}
answers := make([]string, len(questions))
for i := range questions {
answers[i] = answer
}
return answers, nil
}
return ssh.KeyboardInteractive(keyboardInteractiveChallenge)
}
func buildSSHAgentAuthMethod(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
conn, resolved, err := dialSSHAgentWithDebug("auth", timeouts)
if err != nil {
if errors.Is(err, errSSHAgentUnavailable) {
return nil, nil, nil
}
return nil, nil, err
}
if conn == nil {
return nil, nil, nil
}
conn = wrapSSHAgentConnWithDeadline(conn, timeouts.Operation)
started := time.Now()
signers, err := sshagent.NewClient(conn).Signers()
err = normalizeSSHAgentError(err)
logSSHAgentDebug(timeouts.Debug, SSHAgentDebugEvent{
Step: "auth",
Source: resolved.Source,
Endpoint: resolved.Endpoint,
Network: resolved.Network,
Phase: "list",
Status: debugStatus(err),
Duration: time.Since(started),
KeyCount: len(signers),
Err: err,
})
if err != nil {
_ = conn.Close()
return nil, nil, err
}
if len(signers) == 0 {
_ = conn.Close()
return nil, nil, errors.New("ssh-agent has no loaded keys")
}
timeouts.Resolved = resolved
orderedSigners := orderSSHAgentSigners(signers)
filteredSigners := filterSSHAgentSignersForRetry(orderedSigners, timeouts)
if len(filteredSigners) == 0 {
_ = conn.Close()
return nil, nil, errors.New("ssh-agent has no usable keys")
}
return ssh.PublicKeys(filteredSigners...), func() {
_ = conn.Close()
}, nil
}
func orderSSHAgentSigners(signers []ssh.Signer) []ssh.Signer {
type orderedSigner struct {
signer ssh.Signer
index int
score int
comment string
}
ordered := make([]orderedSigner, 0, len(signers))
for index, signer := range signers {
if signer == nil || signer.PublicKey() == nil {
continue
}
ordered = append(ordered, orderedSigner{
signer: signer,
index: index,
score: sshAgentSignerPriority(signer),
comment: sshAgentSignerComment(signer),
})
}
sort.SliceStable(ordered, func(i, j int) bool {
if ordered[i].score != ordered[j].score {
return ordered[i].score > ordered[j].score
}
return ordered[i].index < ordered[j].index
})
result := make([]ssh.Signer, 0, len(ordered))
for _, item := range ordered {
result = append(result, item.signer)
}
return result
}
func sshAgentSignerComment(signer ssh.Signer) string {
if signer == nil {
return ""
}
if key, ok := signer.PublicKey().(*sshagent.Key); ok {
return key.Comment
}
return ""
}
func sshAgentSignerPriority(signer ssh.Signer) int {
comment := strings.TrimSpace(sshAgentSignerComment(signer))
if comment == "" {
return 0
}
score := 0
if priority, ok := parseSSHAgentSignerPriority(comment); ok {
score += 100000 + priority*1000
}
lower := strings.ToLower(comment)
if strings.Contains(lower, "current") {
score += 400
}
if strings.Contains(lower, "cardno:") {
score += 300
}
if strings.Contains(lower, "card ") || strings.Contains(lower, " card") || strings.Contains(lower, "card:") {
score += 100
}
if strings.Contains(lower, "openpgp") || strings.Contains(lower, "gpg") {
score += 50
}
return score
}
func parseSSHAgentSignerPriority(comment string) (int, bool) {
lower := strings.ToLower(comment)
index := strings.Index(lower, "priority=")
if index < 0 {
return 0, false
}
value := strings.TrimSpace(comment[index+len("priority="):])
if value == "" {
return 0, false
}
end := 0
for end < len(value) {
ch := value[end]
if ch == '+' || ch == '-' || (ch >= '0' && ch <= '9') {
end++
continue
}
break
}
if end == 0 {
return 0, false
}
priority, err := strconv.Atoi(value[:end])
if err != nil {
return 0, false
}
return priority, true
}
func filterSSHAgentSignersForRetry(signers []ssh.Signer, timeouts sshAgentTimeouts) []ssh.Signer {
filteredSigners := make([]ssh.Signer, 0, len(signers))
for _, signer := range signers {
if signer == nil {
continue
}
publicKey := signer.PublicKey()
if publicKey == nil {
continue
}
if _, skip := timeouts.SkipFingerprints[ssh.FingerprintSHA256(publicKey)]; skip {
continue
}
if timeouts.SignFailure == nil && timeouts.Debug == nil {
filteredSigners = append(filteredSigners, signer)
continue
}
filteredSigners = append(filteredSigners, wrapSSHAgentSigner(signer, sshAgentSignerOptions{
Resolved: timeouts.Resolved,
Debug: timeouts.Debug,
SignFailure: timeouts.SignFailure,
}))
}
return filteredSigners
}
func newSSHAgentAuthAttempt() *sshAgentAuthAttempt {
return &sshAgentAuthAttempt{
skipFingerprints: make(map[string]struct{}),
}
}
func (a *sshAgentAuthAttempt) begin() {
if a == nil {
return
}
a.mu.Lock()
defer a.mu.Unlock()
a.retryRequested = false
}
func (a *sshAgentAuthAttempt) skipSnapshot() map[string]struct{} {
if a == nil {
return nil
}
a.mu.Lock()
defer a.mu.Unlock()
if len(a.skipFingerprints) == 0 {
return nil
}
snapshot := make(map[string]struct{}, len(a.skipFingerprints))
for fingerprint := range a.skipFingerprints {
snapshot[fingerprint] = struct{}{}
}
return snapshot
}
func (a *sshAgentAuthAttempt) recordSignFailure(publicKey ssh.PublicKey, err error) {
_ = err
if a == nil || publicKey == nil {
return
}
a.skipFingerprint(ssh.FingerprintSHA256(publicKey))
}
func (a *sshAgentAuthAttempt) skipFingerprint(fingerprint string) {
if a == nil {
return
}
a.mu.Lock()
defer a.mu.Unlock()
a.retryRequested = true
if fingerprint != "" {
a.skipFingerprints[fingerprint] = struct{}{}
}
}
func (a *sshAgentAuthAttempt) shouldRetry() bool {
if a == nil {
return false
}
a.mu.Lock()
defer a.mu.Unlock()
return a.retryRequested
}
func checkSSHAgentRetryPending(agentAttempt *sshAgentAuthAttempt) error {
if agentAttempt != nil && agentAttempt.shouldRetry() {
return errRetrySSHAgentAuth
}
return nil
}
type sshAgentRetrySigner struct {
signer ssh.Signer
publicKey ssh.PublicKey
options sshAgentSignerOptions
}
type sshAgentRetryAlgorithmSigner struct {
sshAgentRetrySigner
algorithmSigner ssh.AlgorithmSigner
}
type sshAgentRetryMultiAlgorithmSigner struct {
sshAgentRetryAlgorithmSigner
multiAlgorithmSigner ssh.MultiAlgorithmSigner
}
type sshAgentSignerOptions struct {
Resolved resolvedSSHAgentEndpoint
Debug SSHAgentDebugFunc
SignFailure func(ssh.PublicKey, error)
}
func wrapSSHAgentSignerForRetry(signer ssh.Signer, onFailure func(ssh.PublicKey, error)) ssh.Signer {
return wrapSSHAgentSigner(signer, sshAgentSignerOptions{SignFailure: onFailure})
}
func wrapSSHAgentSigner(signer ssh.Signer, options sshAgentSignerOptions) ssh.Signer {
publicKey := signer.PublicKey()
base := sshAgentRetrySigner{
signer: signer,
publicKey: publicKey,
options: options,
}
if multiAlgorithmSigner, ok := signer.(ssh.MultiAlgorithmSigner); ok {
return &sshAgentRetryMultiAlgorithmSigner{
sshAgentRetryAlgorithmSigner: sshAgentRetryAlgorithmSigner{
sshAgentRetrySigner: base,
algorithmSigner: multiAlgorithmSigner,
},
multiAlgorithmSigner: multiAlgorithmSigner,
}
}
if algorithmSigner, ok := signer.(ssh.AlgorithmSigner); ok {
return &sshAgentRetryAlgorithmSigner{
sshAgentRetrySigner: base,
algorithmSigner: algorithmSigner,
}
}
return &base
}
func (s *sshAgentRetrySigner) PublicKey() ssh.PublicKey {
return s.publicKey
}
func (s *sshAgentRetrySigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) {
started := time.Now()
signature, err := s.signer.Sign(rand, data)
return signature, s.finishSign(started, err)
}
func (s *sshAgentRetrySigner) finishSign(started time.Time, err error) error {
err = normalizeSSHAgentError(err)
s.logSignDebug(started, err)
if err == nil {
return nil
}
if s.options.SignFailure != nil {
s.options.SignFailure(s.publicKey, err)
return wrapSSHAgentSignError(err)
}
return err
}
func (s *sshAgentRetrySigner) logSignDebug(started time.Time, err error) {
if s == nil || s.options.Debug == nil {
return
}
logSSHAgentDebug(s.options.Debug, SSHAgentDebugEvent{
Step: "auth",
Source: s.options.Resolved.Source,
Endpoint: s.options.Resolved.Endpoint,
Network: s.options.Resolved.Network,
Phase: "sign",
Status: debugStatus(err),
Duration: time.Since(started),
Err: err,
})
}
func (s *sshAgentRetryAlgorithmSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) {
algorithm = preferredSSHAgentSignAlgorithm(s.publicKey, algorithm, nil)
started := time.Now()
signature, err := s.algorithmSigner.SignWithAlgorithm(rand, data, algorithm)
return signature, s.finishSign(started, err)
}
func (s *sshAgentRetryMultiAlgorithmSigner) Algorithms() []string {
return s.multiAlgorithmSigner.Algorithms()
}
func (s *sshAgentRetryMultiAlgorithmSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) {
algorithm = preferredSSHAgentSignAlgorithm(s.publicKey, algorithm, s.multiAlgorithmSigner.Algorithms())
started := time.Now()
signature, err := s.multiAlgorithmSigner.SignWithAlgorithm(rand, data, algorithm)
return signature, s.finishSign(started, err)
}
func preferredSSHAgentSignAlgorithm(publicKey ssh.PublicKey, requested string, algorithms []string) string {
if publicKey == nil || publicKey.Type() != ssh.KeyAlgoRSA || requested != ssh.KeyAlgoRSA {
return requested
}
if len(algorithms) == 0 {
return ssh.KeyAlgoRSASHA256
}
for _, algorithm := range algorithms {
if algorithm == ssh.KeyAlgoRSA {
break
}
if algorithm == ssh.KeyAlgoRSASHA256 || algorithm == ssh.KeyAlgoRSASHA512 {
return algorithm
}
}
return requested
}
func wrapSSHAgentSignError(err error) error {
if err == nil {
return nil
}
return fmt.Errorf("%w: %v", errRetrySSHAgentAuth, normalizeSSHAgentError(err))
}
func composeCleanup(funcs ...func()) func() {
if len(funcs) == 0 {
return nil
}
return func() {
for i := len(funcs) - 1; i >= 0; i-- {
if funcs[i] != nil {
funcs[i]()
}
}
}
}
+158
View File
@@ -0,0 +1,158 @@
package starssh
import (
"errors"
"fmt"
"net"
"os"
"strings"
"time"
)
var ErrSSHAgentTimeout = errors.New("ssh-agent timeout")
var dialResolvedSSHAgentFunc = dialResolvedSSHAgent
type sshAgentDialOptions struct {
Endpoint string
Timeout time.Duration
}
type resolvedSSHAgentEndpoint struct {
Endpoint string
Source string
Network string
}
type deadlineAgentConn struct {
net.Conn
timeout time.Duration
}
func resolveSSHAgentEndpoint(options sshAgentDialOptions) (resolvedSSHAgentEndpoint, error) {
endpoint := strings.TrimSpace(options.Endpoint)
if endpoint != "" {
return resolvedSSHAgentEndpoint{
Endpoint: endpoint,
Source: "identity-agent",
Network: defaultSSHAgentNetwork(endpoint),
}, nil
}
endpoint = strings.TrimSpace(os.Getenv("SSH_AUTH_SOCK"))
if endpoint != "" {
return resolvedSSHAgentEndpoint{
Endpoint: endpoint,
Source: "SSH_AUTH_SOCK",
Network: defaultSSHAgentNetwork(endpoint),
}, nil
}
return defaultSSHAgentEndpoint()
}
func dialSSHAgent(options sshAgentDialOptions) (net.Conn, resolvedSSHAgentEndpoint, error) {
resolved, err := resolveSSHAgentEndpoint(options)
if err != nil {
return nil, resolvedSSHAgentEndpoint{}, err
}
conn, err := dialResolvedSSHAgentFunc(resolved, options.Timeout)
if isTimeoutError(err) {
err = fmt.Errorf("%w: %v", ErrSSHAgentTimeout, err)
}
if err != nil {
return nil, resolved, err
}
return conn, resolved, nil
}
func dialSSHAgentWithDebug(step string, timeouts sshAgentTimeouts) (net.Conn, resolvedSSHAgentEndpoint, error) {
options := sshAgentDialOptions{
Endpoint: timeouts.Endpoint,
Timeout: timeouts.Dial,
}
started := time.Now()
conn, resolved, err := dialSSHAgent(options)
logSSHAgentDebug(timeouts.Debug, SSHAgentDebugEvent{
Step: step,
Source: resolved.Source,
Endpoint: resolved.Endpoint,
Network: resolved.Network,
Phase: "dial",
Status: debugStatus(err),
Duration: time.Since(started),
Err: err,
})
return conn, resolved, err
}
func logSSHAgentDebug(debug SSHAgentDebugFunc, event SSHAgentDebugEvent) {
if debug == nil {
return
}
debug(event)
}
func debugStatus(err error) string {
if err != nil {
return "error"
}
return "ok"
}
func wrapSSHAgentConnWithDeadline(conn net.Conn, timeout time.Duration) net.Conn {
if conn == nil || timeout <= 0 {
return conn
}
return &deadlineAgentConn{Conn: conn, timeout: timeout}
}
func (c *deadlineAgentConn) Read(p []byte) (int, error) {
c.setDeadline()
n, err := c.Conn.Read(p)
return n, wrapSSHAgentConnError(err)
}
func (c *deadlineAgentConn) Write(p []byte) (int, error) {
c.setDeadline()
n, err := c.Conn.Write(p)
return n, wrapSSHAgentConnError(err)
}
func (c *deadlineAgentConn) setDeadline() {
if c == nil || c.timeout <= 0 || c.Conn == nil {
return
}
_ = c.Conn.SetDeadline(time.Now().Add(c.timeout))
}
func isTimeoutError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, os.ErrDeadlineExceeded) {
return true
}
var netErr net.Error
return errors.As(err, &netErr) && netErr.Timeout()
}
func wrapSSHAgentConnError(err error) error {
if isTimeoutError(err) {
return fmt.Errorf("%w: %v", ErrSSHAgentTimeout, err)
}
return err
}
func normalizeSSHAgentError(err error) error {
if err == nil {
return nil
}
if errors.Is(err, ErrSSHAgentTimeout) {
return err
}
if strings.Contains(err.Error(), ErrSSHAgentTimeout.Error()) {
return fmt.Errorf("%w: %v", ErrSSHAgentTimeout, err)
}
return err
}
+24
View File
@@ -0,0 +1,24 @@
//go:build !windows
package starssh
import (
"net"
"time"
)
func defaultSSHAgentEndpoint() (resolvedSSHAgentEndpoint, error) {
return resolvedSSHAgentEndpoint{}, errSSHAgentUnavailable
}
func defaultSSHAgentNetwork(endpoint string) string {
return "unix"
}
func dialResolvedSSHAgent(resolved resolvedSSHAgentEndpoint, timeout time.Duration) (net.Conn, error) {
agentSock := resolved.Endpoint
if timeout > 0 {
return net.DialTimeout("unix", agentSock, timeout)
}
return net.Dial("unix", agentSock)
}
+271
View File
@@ -0,0 +1,271 @@
//go:build windows
package starssh
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/Microsoft/go-winio"
"golang.org/x/sys/windows"
)
const defaultWindowsSSHAgentPipe = `\\.\pipe\openssh-ssh-agent`
var errInvalidGPGSocketInfo = errors.New("invalid gpg agent socket file")
type gpgSocketInfo struct {
port uint16
nonce []byte
cygwin bool
}
func defaultSSHAgentEndpoint() (resolvedSSHAgentEndpoint, error) {
return resolvedSSHAgentEndpoint{
Endpoint: defaultWindowsSSHAgentPipe,
Source: "platform-default",
Network: "windows-pipe",
}, nil
}
func defaultSSHAgentNetwork(endpoint string) string {
if _, ok := normalizeWindowsSSHAgentPipe(endpoint); ok {
return "windows-pipe"
}
if isAgentSSHSocketPath(endpoint) {
return "gpg-socket"
}
return "unix"
}
func dialResolvedSSHAgent(resolved resolvedSSHAgentEndpoint, timeout time.Duration) (net.Conn, error) {
if pipePath, ok := normalizeWindowsSSHAgentPipe(resolved.Endpoint); ok {
return dialWindowsNamedPipe(pipePath, timeout, resolved.Source == "platform-default")
}
if isAgentSSHSocketPath(resolved.Endpoint) {
return dialWindowsGPGSocketFile(resolved.Endpoint, timeout)
}
return dialWindowsUnixAgent(resolved.Endpoint, timeout)
}
func dialWindowsNamedPipe(path string, timeout time.Duration, unavailableOnNotFound bool) (net.Conn, error) {
ctx := context.Background()
cancel := func() {}
if timeout > 0 {
ctx, cancel = context.WithTimeout(ctx, timeout)
}
defer cancel()
return dialWindowsNamedPipeContext(ctx, path, unavailableOnNotFound)
}
func normalizeWindowsSSHAgentPipe(endpoint string) (string, bool) {
trimmed := strings.TrimSpace(endpoint)
if trimmed == "" {
return "", false
}
normalized := trimmed
if strings.HasPrefix(normalized, "//./pipe/") {
normalized = `\\.\pipe\` + strings.TrimPrefix(normalized, "//./pipe/")
}
if strings.HasPrefix(normalized, `\\.\pipe\`) {
return normalized, true
}
return "", false
}
func isWindowsPipeUnavailable(err error) bool {
return errors.Is(err, windows.ERROR_FILE_NOT_FOUND) || errors.Is(err, windows.ERROR_PATH_NOT_FOUND)
}
func dialWindowsUnixAgent(endpoint string, timeout time.Duration) (net.Conn, error) {
if timeout > 0 {
return net.DialTimeout("unix", endpoint, timeout)
}
return net.Dial("unix", endpoint)
}
func dialWindowsGPGSocketFile(path string, timeout time.Duration) (net.Conn, error) {
ctx := context.Background()
cancel := func() {}
if timeout > 0 {
ctx, cancel = context.WithTimeout(ctx, timeout)
}
defer cancel()
return dialWindowsGPGSocketFileDepth(ctx, strings.TrimSpace(path), 0)
}
func dialWindowsGPGSocketFileDepth(ctx context.Context, path string, depth int) (net.Conn, error) {
if path == "" {
return nil, fmt.Errorf("gpg agent endpoint is empty")
}
if depth > 8 {
return nil, fmt.Errorf("gpg agent socket redirect loop at %s", path)
}
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
if target, ok := parseGPGAssuanSocketRedirect(data); ok {
target = resolveGPGSocketRedirectTarget(path, target)
if pipePath, ok := normalizeWindowsSSHAgentPipe(target); ok {
return dialWindowsNamedPipeContext(ctx, pipePath, false)
}
return dialWindowsGPGSocketFileDepth(ctx, target, depth+1)
}
info, err := parseGPGSocketInfo(path, data)
if err != nil {
return nil, err
}
return dialWindowsGPGSocketInfo(ctx, info)
}
func dialWindowsGPGSocketInfo(ctx context.Context, info gpgSocketInfo) (net.Conn, error) {
var dialer net.Dialer
conn, err := dialer.DialContext(ctx, "tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(int(info.port))))
if err != nil {
return nil, err
}
if deadline, ok := ctx.Deadline(); ok {
if err := conn.SetDeadline(deadline); err != nil {
_ = conn.Close()
return nil, err
}
}
if _, err := conn.Write(info.nonce); err != nil {
_ = conn.Close()
return nil, err
}
if info.cygwin {
var nonce [16]byte
if _, err := io.ReadFull(conn, nonce[:]); err != nil {
_ = conn.Close()
return nil, err
}
var credential [8]byte
binary.LittleEndian.PutUint32(credential[:4], uint32(os.Getpid()))
if _, err := conn.Write(credential[:]); err != nil {
_ = conn.Close()
return nil, err
}
if _, err := io.ReadFull(conn, credential[:]); err != nil {
_ = conn.Close()
return nil, err
}
}
_ = conn.SetDeadline(time.Time{})
return conn, nil
}
func resolveGPGSocketRedirectTarget(source string, target string) string {
target = strings.TrimSpace(target)
if target == "" || filepath.IsAbs(target) {
return target
}
if _, ok := normalizeWindowsSSHAgentPipe(target); ok {
return target
}
return filepath.Join(filepath.Dir(source), target)
}
func parseGPGSocketInfo(path string, data []byte) (gpgSocketInfo, error) {
if info, ok := parseGPGAssuanSocketInfo(data); ok {
return info, nil
}
if info, ok := parseGPGCygwinSocketInfo(data); ok {
return info, nil
}
return gpgSocketInfo{}, fmt.Errorf("%w %s: expected GnuPG port/nonce socket file; if SSH_AUTH_SOCK was set to this file, restart gpg-agent to recreate it", errInvalidGPGSocketInfo, path)
}
func parseGPGAssuanSocketRedirect(data []byte) (string, bool) {
text := strings.ReplaceAll(string(data), "\r\n", "\n")
text = strings.TrimSuffix(text, "\n")
lines := strings.Split(text, "\n")
if len(lines) != 2 || lines[0] != "%Assuan%" {
return "", false
}
target, ok := strings.CutPrefix(lines[1], "socket=")
if !ok || strings.TrimSpace(target) == "" {
return "", false
}
return os.ExpandEnv(target), true
}
func parseGPGAssuanSocketInfo(data []byte) (gpgSocketInfo, bool) {
newline := bytes.IndexByte(data, '\n')
if newline <= 0 || len(data)-newline-1 != 16 {
return gpgSocketInfo{}, false
}
port64, err := strconv.ParseUint(strings.TrimSpace(string(data[:newline])), 10, 16)
if err != nil || port64 == 0 {
return gpgSocketInfo{}, false
}
nonce := make([]byte, 16)
copy(nonce, data[newline+1:])
return gpgSocketInfo{port: uint16(port64), nonce: nonce}, true
}
func parseGPGCygwinSocketInfo(data []byte) (gpgSocketInfo, bool) {
if !bytes.HasPrefix(data, []byte("!<socket >")) {
return gpgSocketInfo{}, false
}
fields := strings.Fields(strings.TrimRight(string(data[10:]), "\x00"))
if len(fields) != 3 || fields[1] != "s" {
return gpgSocketInfo{}, false
}
port64, err := strconv.ParseUint(fields[0], 10, 16)
if err != nil || port64 == 0 {
return gpgSocketInfo{}, false
}
hexParts := strings.Split(fields[2], "-")
if len(hexParts) != 4 {
return gpgSocketInfo{}, false
}
nonce := make([]byte, 0, 16)
for _, part := range hexParts {
if len(part) != 8 {
return gpgSocketInfo{}, false
}
value, err := strconv.ParseUint(part, 16, 32)
if err != nil {
return gpgSocketInfo{}, false
}
var chunk [4]byte
binary.LittleEndian.PutUint32(chunk[:], uint32(value))
nonce = append(nonce, chunk[:]...)
}
return gpgSocketInfo{port: uint16(port64), nonce: nonce, cygwin: true}, true
}
func isAgentSSHSocketPath(endpoint string) bool {
normalized := strings.ToLower(strings.TrimSpace(endpoint))
return strings.HasSuffix(normalized, "s.gpg-agent.ssh")
}
func dialWindowsNamedPipeContext(ctx context.Context, path string, unavailableOnNotFound bool) (net.Conn, error) {
if ctx == nil {
ctx = context.Background()
}
conn, err := winio.DialPipeContext(ctx, path)
if err != nil && unavailableOnNotFound && isWindowsPipeUnavailable(err) {
return nil, errSSHAgentUnavailable
}
if err != nil {
return nil, err
}
return conn, nil
}
+152
View File
@@ -0,0 +1,152 @@
//go:build windows
package starssh
import (
"bytes"
"errors"
"io"
"net"
"os"
"path/filepath"
"strconv"
"testing"
"time"
)
func TestParseGPGAssuanSocketInfo(t *testing.T) {
info, ok := parseGPGAssuanSocketInfo([]byte("7247\n0123456789abcdef"))
if !ok {
t.Fatal("expected Assuan socket info to parse")
}
if info.port != 7247 || string(info.nonce) != "0123456789abcdef" || info.cygwin {
t.Fatalf("info=%+v nonce=%x", info, info.nonce)
}
}
func TestParseGPGCygwinSocketInfo(t *testing.T) {
info, ok := parseGPGCygwinSocketInfo([]byte("!<socket >7247 s 00000001-02030405-06070809-0a0b0c0d\x00"))
if !ok {
t.Fatal("expected Cygwin socket info to parse")
}
want := []byte{1, 0, 0, 0, 5, 4, 3, 2, 9, 8, 7, 6, 13, 12, 11, 10}
if info.port != 7247 || string(info.nonce) != string(want) || !info.cygwin {
t.Fatalf("info=%+v nonce=%x", info, info.nonce)
}
}
func TestParseGPGAssuanSocketRedirect(t *testing.T) {
t.Setenv("STARSSH_TEST_PIPE", `\\.\pipe\openssh-ssh-agent`)
target, ok := parseGPGAssuanSocketRedirect([]byte("%Assuan%\r\nsocket=${STARSSH_TEST_PIPE}\r\n"))
if !ok {
t.Fatal("expected Assuan redirect to parse")
}
if target != `\\.\pipe\openssh-ssh-agent` {
t.Fatalf("target=%q", target)
}
}
func TestReadInvalidAgentSSHSocketReturnsGPGSocketError(t *testing.T) {
path := t.TempDir() + "/S.gpg-agent.ssh"
if err := os.WriteFile(path, []byte("not a socket info file"), 0o600); err != nil {
t.Fatalf("write socket file: %v", err)
}
_, err := dialResolvedSSHAgent(resolvedSSHAgentEndpoint{
Endpoint: path,
Source: "SSH_AUTH_SOCK",
Network: defaultSSHAgentNetwork(path),
}, 0)
if !errors.Is(err, errInvalidGPGSocketInfo) {
t.Fatalf("err=%v want errInvalidGPGSocketInfo", err)
}
}
func TestMissingAgentSSHSocketReturnsReadError(t *testing.T) {
path := filepath.Join(t.TempDir(), "S.gpg-agent.ssh")
_, err := dialResolvedSSHAgent(resolvedSSHAgentEndpoint{
Endpoint: path,
Source: "identity-agent",
Network: defaultSSHAgentNetwork(path),
}, 0)
if err == nil {
t.Fatal("expected missing GPG socket file error")
}
if !errors.Is(err, os.ErrNotExist) {
t.Fatalf("err=%v want os.ErrNotExist", err)
}
}
func TestUnreadableAgentSSHSocketReturnsReadError(t *testing.T) {
path := filepath.Join(t.TempDir(), "S.gpg-agent.ssh")
if err := os.Mkdir(path, 0o700); err != nil {
t.Fatalf("mkdir socket path: %v", err)
}
_, err := dialResolvedSSHAgent(resolvedSSHAgentEndpoint{
Endpoint: path,
Source: "identity-agent",
Network: defaultSSHAgentNetwork(path),
}, 0)
if err == nil {
t.Fatal("expected unreadable GPG socket file error")
}
if errors.Is(err, errInvalidGPGSocketInfo) {
t.Fatalf("err=%v should expose read failure before parse", err)
}
}
func TestDialWindowsGPGSocketFilePerformsNonceHandshake(t *testing.T) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen tcp: %v", err)
}
defer listener.Close()
type handshakeResult struct {
nonce []byte
err error
}
resultCh := make(chan handshakeResult, 1)
go func() {
conn, err := listener.Accept()
if err != nil {
resultCh <- handshakeResult{err: err}
return
}
defer conn.Close()
nonce := make([]byte, 16)
if _, err := io.ReadFull(conn, nonce); err != nil {
resultCh <- handshakeResult{err: err}
return
}
resultCh <- handshakeResult{nonce: append([]byte(nil), nonce...)}
}()
socketPath := filepath.Join(t.TempDir(), "S.gpg-agent.ssh")
if err := os.WriteFile(socketPath, []byte(strconv.Itoa(listener.Addr().(*net.TCPAddr).Port)+"\n0123456789abcdef"), 0o600); err != nil {
t.Fatalf("write socket file: %v", err)
}
conn, err := dialWindowsGPGSocketFile(socketPath, time.Second)
if err != nil {
t.Fatalf("dialWindowsGPGSocketFile: %v", err)
}
_ = conn.Close()
var result handshakeResult
select {
case result = <-resultCh:
case <-time.After(time.Second):
t.Fatal("listener did not accept GPG socket connection")
}
if result.err != nil {
t.Fatalf("listener handshake error: %v", result.err)
}
if !bytes.Equal(result.nonce, []byte("0123456789abcdef")) {
t.Fatalf("nonce=%q", result.nonce)
}
}
+135
View File
@@ -0,0 +1,135 @@
package starssh
import (
"errors"
"golang.org/x/crypto/ssh"
)
var errSSHClientClosing = errors.New("ssh client is closing")
type sshClientRequester interface {
SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error)
Close() error
}
var closeSSHClient = func(client sshClientRequester) error {
if client == nil {
return nil
}
return client.Close()
}
func (s *StarSSH) snapshotSSHClient() *ssh.Client {
if s == nil {
return nil
}
s.stateMu.RLock()
defer s.stateMu.RUnlock()
return s.Client
}
func (s *StarSSH) requireSSHClient() (*ssh.Client, error) {
if s == nil {
return nil, errors.New("ssh client is nil")
}
if s.closing.Load() {
return nil, errSSHClientClosing
}
client := s.snapshotSSHClient()
if client == nil {
return nil, errors.New("ssh client is nil")
}
if s.closing.Load() {
return nil, errSSHClientClosing
}
return client, nil
}
func (s *StarSSH) setTransport(client *ssh.Client, upstream *StarSSH) {
if s == nil {
return
}
s.stateMu.Lock()
defer s.stateMu.Unlock()
s.Client = client
s.upstream = upstream
s.online = client != nil
s.closing.Store(false)
}
func (s *StarSSH) detachTransport() (*ssh.Client, *StarSSH) {
if s == nil {
return nil, nil
}
s.stateMu.Lock()
defer s.stateMu.Unlock()
client := s.Client
upstream := s.upstream
s.Client = nil
s.upstream = nil
s.online = false
return client, upstream
}
func (s *StarSSH) takeKeepaliveHandles() (chan struct{}, chan struct{}) {
if s == nil {
return nil, nil
}
s.keepaliveMu.Lock()
defer s.keepaliveMu.Unlock()
stop := s.keepaliveStop
done := s.keepaliveDone
s.keepaliveStop = nil
s.keepaliveDone = nil
return stop, done
}
func (s *StarSSH) closeTransport(waitKeepalive bool) error {
if s == nil {
return nil
}
s.closing.Store(true)
_ = s.closeReusableSFTPClient()
agentForwarder := s.takeAgentForwarder()
client, upstream := s.detachTransport()
stop, done := s.takeKeepaliveHandles()
if stop != nil {
close(stop)
}
var closeErr error
if agentForwarder != nil {
closeErr = normalizeAlreadyClosedError(agentForwarder.Close())
}
if client != nil {
if err := normalizeAlreadyClosedError(closeSSHClient(client)); closeErr == nil {
closeErr = err
}
}
if waitKeepalive && done != nil {
<-done
}
if upstreamErr := closeUpstream(upstream); closeErr == nil {
closeErr = upstreamErr
}
return closeErr
}
func (s *StarSSH) canAttachAgentForwarder(client *ssh.Client) bool {
if s == nil || client == nil || s.closing.Load() {
return false
}
s.stateMu.RLock()
defer s.stateMu.RUnlock()
return !s.closing.Load() && s.Client == client
}
+431
View File
@@ -0,0 +1,431 @@
package starssh
import (
"context"
"errors"
"io"
"sync"
"golang.org/x/crypto/ssh"
)
func (s *StarSSH) NewTerminal(config *TerminalConfig) (*TerminalSession, error) {
session, err := s.NewPTYSession(config)
if err != nil {
return nil, err
}
stdin, err := session.StdinPipe()
if err != nil {
_ = session.Close()
return nil, err
}
stdout, err := session.StdoutPipe()
if err != nil {
_ = session.Close()
return nil, err
}
stderr, err := session.StderrPipe()
if err != nil {
_ = session.Close()
return nil, err
}
if err := session.Shell(); err != nil {
_ = session.Close()
return nil, err
}
return &TerminalSession{
Session: session,
stdin: stdin,
stdout: stdout,
stderr: stderr,
runDone: make(chan struct{}),
waitDone: make(chan struct{}),
}, nil
}
func (t *TerminalSession) AttachIO(stdin io.Reader, stdout io.Writer, stderr io.Writer) {
if t == nil {
return
}
t.attachMu.Lock()
defer t.attachMu.Unlock()
t.in = stdin
t.out = stdout
t.errOut = stderr
}
func (t *TerminalSession) StdinWriter() io.Writer {
if t == nil {
return nil
}
return t.stdin
}
func (t *TerminalSession) StdoutReader() io.Reader {
if t == nil {
return nil
}
return t.stdout
}
func (t *TerminalSession) StderrReader() io.Reader {
if t == nil {
return nil
}
return t.stderr
}
func (t *TerminalSession) Write(data []byte) (int, error) {
if t == nil || t.stdin == nil {
return 0, errors.New("terminal stdin is not initialized")
}
return t.stdin.Write(data)
}
func (t *TerminalSession) SendControl(control TerminalControl) error {
_, err := t.Write([]byte{byte(control)})
return err
}
func (t *TerminalSession) Interrupt() error {
return t.SendControl(TerminalControlInterrupt)
}
func (t *TerminalSession) Signal(sig ssh.Signal) error {
if t == nil || t.Session == nil {
return errors.New("terminal session is not initialized")
}
if sig == "" {
return errors.New("signal is empty")
}
return t.Session.Signal(sig)
}
func (t *TerminalSession) Run(ctx context.Context) error {
if t == nil {
return errors.New("terminal session is nil")
}
if ctx == nil {
ctx = context.Background()
}
t.runOnce.Do(func() {
t.runErr = t.run(ctx)
close(t.runDone)
})
<-t.runDone
return t.runErr
}
func (t *TerminalSession) run(ctx context.Context) error {
if t.Session == nil {
return errors.New("terminal session is not initialized")
}
t.attachMu.RLock()
in := t.in
out := t.out
errOut := t.errOut
t.attachMu.RUnlock()
if out == nil {
out = io.Discard
}
if errOut == nil {
errOut = out
}
inputReader, cancelInput, inputCancelable, err := prepareTerminalInputReader(in)
if err != nil {
return err
}
defer cancelInput()
var copyWG sync.WaitGroup
doneCopy := make(chan struct{})
copyWG.Add(2)
go func() {
defer copyWG.Done()
if t.stdout != nil {
_, _ = io.Copy(out, t.stdout)
}
}()
go func() {
defer copyWG.Done()
if t.stderr != nil {
_, _ = io.Copy(errOut, t.stderr)
}
}()
go func() {
copyWG.Wait()
close(doneCopy)
}()
var doneInput chan struct{}
if inputReader != nil && t.stdin != nil {
doneInput = make(chan struct{})
go func() {
defer close(doneInput)
_, _ = io.Copy(t.stdin, inputReader)
_ = t.stdin.Close()
}()
}
waitInputPump := func() {
if doneInput == nil {
return
}
select {
case <-doneInput:
return
default:
}
if inputCancelable {
cancelInput()
<-doneInput
}
}
type waitResult struct {
info TerminalExitInfo
err error
}
waitCh := make(chan waitResult, 1)
go func() {
info, err := t.WaitResult()
waitCh <- waitResult{info: info, err: err}
}()
select {
case result := <-waitCh:
waitInputPump()
<-doneCopy
if result.err != nil {
return result.err
}
return result.info.CommandError()
case <-ctx.Done():
t.markCloseReason(terminalCloseReasonFromErr(ctx.Err()), ctx.Err())
cancelInput()
_ = t.Close()
<-waitCh
waitInputPump()
<-doneCopy
return ctx.Err()
}
}
func (t *TerminalSession) Wait() error {
info, err := t.WaitResult()
if err != nil {
return err
}
return info.CommandError()
}
func (t *TerminalSession) WaitResult() (TerminalExitInfo, error) {
waitErr := t.waitRaw()
info, closeErr := t.snapshotExitState()
if closeErr != nil {
return info, closeErr
}
if waitErr == nil {
return info, nil
}
var exitErr *ssh.ExitError
if errors.As(waitErr, &exitErr) {
return info, nil
}
if normalizeAlreadyClosedError(waitErr) == nil || info.Reason == TerminalCloseReasonClosed {
return info, nil
}
return info, waitErr
}
func (t *TerminalSession) ExitInfo() TerminalExitInfo {
if t == nil {
return TerminalExitInfo{}
}
t.stateMu.RLock()
defer t.stateMu.RUnlock()
return t.exitInfo
}
func (info TerminalExitInfo) Success() bool {
return info.Reason == TerminalCloseReasonExit && info.ExitCode == 0 && info.ExitSignal == ""
}
func (info TerminalExitInfo) CommandError() error {
if info.Reason != TerminalCloseReasonExit && info.Reason != TerminalCloseReasonSignal {
return nil
}
if info.ExitCode == 0 && info.ExitSignal == "" {
return nil
}
return &ExecExitError{
Status: info.ExitCode,
Signal: info.ExitSignal,
Message: info.ExitMessage,
}
}
func (t *TerminalSession) Resize(columns int, rows int) error {
if t == nil || t.Session == nil {
return errors.New("terminal session is not initialized")
}
if columns <= 0 || rows <= 0 {
return errors.New("columns and rows must be > 0")
}
return t.Session.WindowChange(rows, columns)
}
func (t *TerminalSession) Close() error {
if t == nil {
return nil
}
var closeErr error
t.closeOnce.Do(func() {
if t.stdin != nil {
_ = t.stdin.Close()
}
if t.Session != nil {
closeErr = normalizeAlreadyClosedError(t.Session.Close())
}
})
if closeErr != nil {
t.markCloseReason(TerminalCloseReasonTransportError, closeErr)
}
return closeErr
}
func (t *TerminalSession) waitRaw() error {
if t == nil || t.Session == nil {
return errors.New("terminal session is not initialized")
}
t.waitOnce.Do(func() {
go func() {
waitErr := t.Session.Wait()
t.setWaitResult(waitErr)
close(t.waitDone)
}()
})
<-t.waitDone
t.stateMu.RLock()
defer t.stateMu.RUnlock()
return t.waitErr
}
func (t *TerminalSession) setWaitResult(waitErr error) {
if t == nil {
return
}
t.stateMu.Lock()
defer t.stateMu.Unlock()
t.waitErr = waitErr
t.exitInfo = buildTerminalExitInfo(waitErr, t.closeReason)
}
func (t *TerminalSession) markCloseReason(reason TerminalCloseReason, err error) {
if t == nil || reason == TerminalCloseReasonUnknown {
return
}
t.stateMu.Lock()
defer t.stateMu.Unlock()
if terminalCloseReasonPriority(reason) >= terminalCloseReasonPriority(t.closeReason) {
t.closeReason = reason
}
if err != nil && t.closeErr == nil {
t.closeErr = err
}
}
func (t *TerminalSession) snapshotExitState() (TerminalExitInfo, error) {
if t == nil {
return TerminalExitInfo{}, nil
}
t.stateMu.RLock()
defer t.stateMu.RUnlock()
return t.exitInfo, t.closeErr
}
func buildTerminalExitInfo(waitErr error, overrideReason TerminalCloseReason) TerminalExitInfo {
info := TerminalExitInfo{}
if waitErr == nil {
info.Reason = TerminalCloseReasonExit
} else {
var exitErr *ssh.ExitError
switch {
case errors.As(waitErr, &exitErr):
info.ExitCode = exitErr.ExitStatus()
info.ExitSignal = exitErr.Signal()
info.ExitMessage = exitErr.Msg()
if info.ExitSignal != "" {
info.Reason = TerminalCloseReasonSignal
} else {
info.Reason = TerminalCloseReasonExit
}
case normalizeAlreadyClosedError(waitErr) == nil:
info.Reason = TerminalCloseReasonClosed
case errors.Is(waitErr, context.Canceled):
info.Reason = TerminalCloseReasonContextCanceled
case errors.Is(waitErr, context.DeadlineExceeded):
info.Reason = TerminalCloseReasonDeadlineExceeded
default:
info.Reason = TerminalCloseReasonTransportError
}
}
if overrideReason != TerminalCloseReasonUnknown &&
terminalCloseReasonPriority(overrideReason) >= terminalCloseReasonPriority(info.Reason) {
info.Reason = overrideReason
}
return info
}
func terminalCloseReasonFromErr(err error) TerminalCloseReason {
switch {
case errors.Is(err, context.DeadlineExceeded):
return TerminalCloseReasonDeadlineExceeded
case errors.Is(err, context.Canceled):
return TerminalCloseReasonContextCanceled
case err != nil:
return TerminalCloseReasonTransportError
default:
return TerminalCloseReasonUnknown
}
}
func terminalCloseReasonPriority(reason TerminalCloseReason) int {
switch reason {
case TerminalCloseReasonContextCanceled, TerminalCloseReasonDeadlineExceeded:
return 30
case TerminalCloseReasonTransportError:
return 20
case TerminalCloseReasonClosed:
return 10
case TerminalCloseReasonSignal, TerminalCloseReasonExit:
return 5
default:
return 0
}
}
+225
View File
@@ -0,0 +1,225 @@
package starssh
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"os"
"reflect"
"sync"
"unsafe"
)
// TerminalInputSourceProvider lets wrapper readers expose a closer-friendly source reader.
// Implementations that buffer data should return a source that already includes any prefetched bytes.
type TerminalInputSourceProvider interface {
TerminalInputSource() io.Reader
}
// TerminalInputCanceler lets wrapper readers expose an explicit cancellation hook.
// It is useful for line editors or custom buffered readers that cannot safely expose a raw io.ReadCloser.
type TerminalInputCanceler interface {
TerminalInputCancel() error
}
// TerminalInputAdapter adapts wrapper readers into a cancelable terminal input source.
// Reader is what TerminalSession consumes, Source is the closer-friendly underlying reader when available.
type TerminalInputAdapter struct {
Reader io.Reader
Source io.Reader
Cancel func() error
}
func (a TerminalInputAdapter) Read(p []byte) (int, error) {
if a.Reader == nil {
return 0, io.EOF
}
return a.Reader.Read(p)
}
func (a TerminalInputAdapter) TerminalInputSource() io.Reader {
if a.Source != nil {
return a.Source
}
return a.Reader
}
func (a TerminalInputAdapter) TerminalInputCancel() error {
if a.Cancel != nil {
return a.Cancel()
}
if closer, ok := a.Source.(io.Closer); ok && closer != nil {
return closer.Close()
}
if closer, ok := a.Reader.(io.Closer); ok && closer != nil {
return closer.Close()
}
return nil
}
func prepareTerminalInputReader(in io.Reader) (io.Reader, func(), bool, error) {
if in == nil {
return nil, func() {}, false, nil
}
var cancelOnce sync.Once
wrapCancel := func(fn func()) func() {
return func() {
cancelOnce.Do(fn)
}
}
if provider, ok := in.(TerminalInputSourceProvider); ok {
source := provider.TerminalInputSource()
if source == nil || sameReader(source, in) {
return prepareDirectTerminalInputReader(in, wrapCancel)
}
prepared, cancel, cancelable, err := prepareTerminalInputReader(source)
if err != nil {
return nil, nil, false, err
}
if canceler, ok := in.(TerminalInputCanceler); ok {
return prepared, wrapCancel(func() {
cancel()
_ = canceler.TerminalInputCancel()
}), true, nil
}
return prepared, cancel, cancelable, nil
}
return prepareDirectTerminalInputReader(in, wrapCancel)
}
func prepareDirectTerminalInputReader(in io.Reader, wrapCancel func(func()) func()) (io.Reader, func(), bool, error) {
if in == nil {
return nil, func() {}, false, nil
}
switch typed := in.(type) {
case *bufio.Reader:
return prepareBufferedTerminalInputReader(typed)
case *bufio.ReadWriter:
if typed.Reader == nil {
return in, func() {}, false, nil
}
return prepareBufferedTerminalInputReader(typed.Reader)
}
if canceler, ok := in.(TerminalInputCanceler); ok {
return in, wrapCancel(func() {
_ = canceler.TerminalInputCancel()
}), true, nil
}
if file, ok := in.(*os.File); ok {
dup, err := duplicateTerminalInputFile(file)
if err != nil {
return nil, nil, false, fmt.Errorf("duplicate terminal input: %w", err)
}
return dup, wrapCancel(func() {
_ = dup.Close()
}), true, nil
}
if closer, ok := in.(io.ReadCloser); ok {
return closer, wrapCancel(func() {
_ = closer.Close()
}), true, nil
}
return in, func() {}, false, nil
}
func prepareBufferedTerminalInputReader(reader *bufio.Reader) (io.Reader, func(), bool, error) {
if reader == nil {
return nil, func() {}, false, nil
}
bufferedPrefix, err := snapshotBufferedPrefix(reader)
if err != nil {
return nil, nil, false, err
}
underlying := unwrapBufioReader(reader)
if underlying == nil {
if len(bufferedPrefix) == 0 {
return reader, func() {}, false, nil
}
return io.MultiReader(bytes.NewReader(bufferedPrefix), reader), func() {}, false, nil
}
prepared, cancel, cancelable, err := prepareTerminalInputReader(underlying)
if err != nil {
return nil, nil, false, err
}
if len(bufferedPrefix) == 0 {
return prepared, cancel, cancelable, nil
}
if prepared == nil {
return bytes.NewReader(bufferedPrefix), cancel, cancelable, nil
}
return io.MultiReader(bytes.NewReader(bufferedPrefix), prepared), cancel, cancelable, nil
}
func snapshotBufferedPrefix(reader *bufio.Reader) ([]byte, error) {
if reader == nil {
return nil, nil
}
buffered := reader.Buffered()
if buffered == 0 {
return nil, nil
}
chunk, err := reader.Peek(buffered)
if err != nil && !errors.Is(err, io.EOF) {
return nil, fmt.Errorf("peek terminal input buffer: %w", err)
}
prefix := append([]byte(nil), chunk...)
if _, err := reader.Discard(len(prefix)); err != nil {
return nil, fmt.Errorf("discard terminal input buffer: %w", err)
}
return prefix, nil
}
func unwrapBufioReader(reader *bufio.Reader) io.Reader {
if reader == nil {
return nil
}
value := reflect.ValueOf(reader)
if value.Kind() != reflect.Pointer || value.IsNil() {
return nil
}
field := value.Elem().FieldByName("rd")
if !field.IsValid() {
return nil
}
underlyingValue := reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem()
underlying, ok := underlyingValue.Interface().(io.Reader)
if !ok || underlying == nil || sameReader(underlying, reader) {
return nil
}
return underlying
}
func sameReader(left io.Reader, right io.Reader) bool {
if left == nil || right == nil {
return false
}
leftValue := reflect.ValueOf(left)
rightValue := reflect.ValueOf(right)
if !leftValue.IsValid() || !rightValue.IsValid() {
return false
}
if leftValue.Kind() != reflect.Pointer || rightValue.Kind() != reflect.Pointer {
return false
}
return leftValue.Pointer() == rightValue.Pointer()
}
+49
View File
@@ -0,0 +1,49 @@
package starssh
import (
"bufio"
"io"
"testing"
"time"
)
func TestPrepareTerminalInputReaderAdapterCancelUnblocksRead(t *testing.T) {
reader, writer := io.Pipe()
defer reader.Close()
defer writer.Close()
adapter := TerminalInputAdapter{
Reader: bufio.NewReader(reader),
Source: reader,
Cancel: func() error {
return reader.CloseWithError(io.ErrClosedPipe)
},
}
prepared, cancel, cancelable, err := prepareTerminalInputReader(adapter)
if err != nil {
t.Fatalf("prepare adapter input: %v", err)
}
if !cancelable {
t.Fatal("expected adapter-backed input to be cancelable")
}
done := make(chan error, 1)
go func() {
buf := make([]byte, 1)
_, readErr := prepared.Read(buf)
done <- readErr
}()
time.Sleep(50 * time.Millisecond)
cancel()
select {
case readErr := <-done:
if readErr == nil {
t.Fatal("expected adapter cancel to interrupt blocking read")
}
case <-time.After(time.Second):
t.Fatal("blocking adapter input did not unblock after cancel")
}
}
+290
View File
@@ -0,0 +1,290 @@
package starssh
import (
"bufio"
"bytes"
"io"
"os"
"testing"
"time"
)
type terminalInputProvider struct {
io.Reader
source io.Reader
}
func (p terminalInputProvider) TerminalInputSource() io.Reader {
return p.source
}
type prefixedReadCloser struct {
io.Reader
io.Closer
}
func TestPrepareTerminalInputReaderBufioReaderPreservesBufferedBytes(t *testing.T) {
reader, writer, err := os.Pipe()
if err != nil {
t.Fatalf("create pipe: %v", err)
}
defer reader.Close()
if _, err := writer.Write([]byte("hello world")); err != nil {
writer.Close()
t.Fatalf("write pipe: %v", err)
}
if err := writer.Close(); err != nil {
t.Fatalf("close writer: %v", err)
}
buffered := bufio.NewReaderSize(reader, 4)
peeked, err := buffered.Peek(5)
if err != nil {
t.Fatalf("prime buffer: %v", err)
}
if string(peeked) != "hello" {
t.Fatalf("unexpected buffered prefix: %q", string(peeked))
}
prepared, cancel, cancelable, err := prepareTerminalInputReader(buffered)
if err != nil {
t.Fatalf("prepare input: %v", err)
}
defer cancel()
if !cancelable {
t.Fatal("expected cancelable reader")
}
data, err := io.ReadAll(prepared)
if err != nil {
t.Fatalf("read prepared input: %v", err)
}
if string(data) != "hello world" {
t.Fatalf("unexpected prepared input: %q", string(data))
}
}
func TestPrepareTerminalInputReaderBufioReadWriterPreservesBufferedBytes(t *testing.T) {
reader, writer, err := os.Pipe()
if err != nil {
t.Fatalf("create pipe: %v", err)
}
defer reader.Close()
if _, err := writer.Write([]byte("buffered payload")); err != nil {
writer.Close()
t.Fatalf("write pipe: %v", err)
}
if err := writer.Close(); err != nil {
t.Fatalf("close writer: %v", err)
}
readWriter := bufio.NewReadWriter(bufio.NewReaderSize(reader, 8), bufio.NewWriter(io.Discard))
if _, err := readWriter.Reader.Peek(8); err != nil {
t.Fatalf("prime readwriter buffer: %v", err)
}
prepared, cancel, cancelable, err := prepareTerminalInputReader(readWriter)
if err != nil {
t.Fatalf("prepare readwriter input: %v", err)
}
defer cancel()
if !cancelable {
t.Fatal("expected cancelable readwriter input")
}
data, err := io.ReadAll(prepared)
if err != nil {
t.Fatalf("read prepared readwriter input: %v", err)
}
if string(data) != "buffered payload" {
t.Fatalf("unexpected prepared readwriter input: %q", string(data))
}
}
func TestPrepareTerminalInputReaderBufioReaderFallbackKeepsData(t *testing.T) {
buffered := bufio.NewReader(bytes.NewBufferString("abc123"))
if _, err := buffered.Peek(3); err != nil {
t.Fatalf("prime buffer: %v", err)
}
prepared, cancel, cancelable, err := prepareTerminalInputReader(buffered)
if err != nil {
t.Fatalf("prepare fallback input: %v", err)
}
defer cancel()
if cancelable {
t.Fatal("expected non-cancelable reader")
}
data, err := io.ReadAll(prepared)
if err != nil {
t.Fatalf("read fallback input: %v", err)
}
if string(data) != "abc123" {
t.Fatalf("unexpected fallback input: %q", string(data))
}
}
func TestPrepareTerminalInputReaderProviderPrefersExplicitSource(t *testing.T) {
reader, writer, err := os.Pipe()
if err != nil {
t.Fatalf("create pipe: %v", err)
}
defer reader.Close()
if _, err := writer.Write([]byte("provider data")); err != nil {
writer.Close()
t.Fatalf("write pipe: %v", err)
}
if err := writer.Close(); err != nil {
t.Fatalf("close writer: %v", err)
}
buffered := bufio.NewReader(reader)
prefix, err := buffered.Peek(len("provider data"))
if err != nil {
t.Fatalf("prime buffer: %v", err)
}
source := prefixedReadCloser{
Reader: io.MultiReader(bytes.NewReader(append([]byte(nil), prefix...)), reader),
Closer: reader,
}
provider := terminalInputProvider{
Reader: buffered,
source: source,
}
prepared, cancel, cancelable, err := prepareTerminalInputReader(provider)
if err != nil {
t.Fatalf("prepare provider input: %v", err)
}
defer cancel()
if !cancelable {
t.Fatal("expected provider-backed input to be cancelable")
}
data, err := io.ReadAll(prepared)
if err != nil {
t.Fatalf("read provider input: %v", err)
}
if string(data) != "provider data" {
t.Fatalf("unexpected provider input: %q", string(data))
}
}
func TestPrepareTerminalInputReaderProviderCancelUnblocksRead(t *testing.T) {
reader, writer := io.Pipe()
defer reader.Close()
defer writer.Close()
buffered := bufio.NewReader(reader)
provider := terminalInputProvider{
Reader: buffered,
source: reader,
}
prepared, cancel, cancelable, err := prepareTerminalInputReader(provider)
if err != nil {
t.Fatalf("prepare input: %v", err)
}
if !cancelable {
t.Fatal("expected provider-backed input to be cancelable")
}
done := make(chan error, 1)
go func() {
buf := make([]byte, 1)
_, readErr := prepared.Read(buf)
done <- readErr
}()
time.Sleep(50 * time.Millisecond)
cancel()
select {
case readErr := <-done:
if readErr == nil {
t.Fatal("expected cancel to interrupt blocking read")
}
case <-time.After(time.Second):
t.Fatal("blocking read did not unblock after cancel")
}
}
func TestPrepareTerminalInputReaderBufioReaderCancelUnblocksRead(t *testing.T) {
reader, writer := io.Pipe()
defer reader.Close()
defer writer.Close()
buffered := bufio.NewReader(reader)
prepared, cancel, cancelable, err := prepareTerminalInputReader(buffered)
if err != nil {
t.Fatalf("prepare input: %v", err)
}
if !cancelable {
t.Fatal("expected bufio reader to be cancelable")
}
done := make(chan error, 1)
go func() {
buf := make([]byte, 1)
_, readErr := prepared.Read(buf)
done <- readErr
}()
time.Sleep(50 * time.Millisecond)
cancel()
select {
case readErr := <-done:
if readErr == nil {
t.Fatal("expected cancel to interrupt blocking read")
}
case <-time.After(time.Second):
t.Fatal("blocking bufio reader did not unblock after cancel")
}
}
func TestPrepareTerminalInputReaderBufioReadWriterCancelUnblocksRead(t *testing.T) {
reader, writer := io.Pipe()
defer reader.Close()
defer writer.Close()
readWriter := bufio.NewReadWriter(bufio.NewReader(reader), bufio.NewWriter(io.Discard))
prepared, cancel, cancelable, err := prepareTerminalInputReader(readWriter)
if err != nil {
t.Fatalf("prepare readwriter input: %v", err)
}
if !cancelable {
t.Fatal("expected bufio readwriter to be cancelable")
}
done := make(chan error, 1)
go func() {
buf := make([]byte, 1)
_, readErr := prepared.Read(buf)
done <- readErr
}()
time.Sleep(50 * time.Millisecond)
cancel()
select {
case readErr := <-done:
if readErr == nil {
t.Fatal("expected cancel to interrupt blocking read")
}
case <-time.After(time.Second):
t.Fatal("blocking readwriter input did not unblock after cancel")
}
}
+21
View File
@@ -0,0 +1,21 @@
//go:build !windows
package starssh
import (
"os"
"syscall"
)
func duplicateTerminalInputFile(file *os.File) (*os.File, error) {
if file == nil {
return nil, os.ErrInvalid
}
fd, err := syscall.Dup(int(file.Fd()))
if err != nil {
return nil, err
}
syscall.CloseOnExec(fd)
return os.NewFile(uintptr(fd), file.Name()), nil
}
+32
View File
@@ -0,0 +1,32 @@
//go:build windows
package starssh
import (
"os"
"golang.org/x/sys/windows"
)
func duplicateTerminalInputFile(file *os.File) (*os.File, error) {
if file == nil {
return nil, os.ErrInvalid
}
currentProcess := windows.CurrentProcess()
var duplicated windows.Handle
err := windows.DuplicateHandle(
currentProcess,
windows.Handle(file.Fd()),
currentProcess,
&duplicated,
0,
false,
windows.DUPLICATE_SAME_ACCESS,
)
if err != nil {
return nil, err
}
return os.NewFile(uintptr(duplicated), file.Name()), nil
}
+336
View File
@@ -0,0 +1,336 @@
package starssh
import (
"bufio"
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"net/http"
"strconv"
"strings"
"time"
)
type bufferedConn struct {
net.Conn
reader *bufio.Reader
}
func (c *bufferedConn) Read(p []byte) (int, error) {
if c == nil || c.reader == nil {
return 0, io.EOF
}
return c.reader.Read(p)
}
func resolveDialContext(info LoginInput) DialContextFunc {
if info.DialContext != nil {
return info.DialContext
}
dialer := &net.Dialer{
Timeout: effectiveDialTimeout(info),
}
return dialer.DialContext
}
func dialTargetConn(ctx context.Context, info LoginInput) (net.Conn, *StarSSH, error) {
targetAddr := joinHostPort(info.Addr, info.Port)
if info.Jump != nil {
return dialViaJump(ctx, info, targetAddr)
}
dialContext := resolveDialContext(info)
proxyConfig := normalizeProxyConfig(info.Proxy, effectiveDialTimeout(info))
if proxyConfig != nil {
return dialViaProxy(ctx, dialContext, *proxyConfig, targetAddr)
}
conn, err := dialContext(ctx, "tcp", targetAddr)
return conn, nil, err
}
func dialViaJump(ctx context.Context, info LoginInput, targetAddr string) (net.Conn, *StarSSH, error) {
if info.Jump == nil {
return nil, nil, errors.New("jump login info is nil")
}
jumpClient, err := loginWithContext(ctx, *info.Jump)
if err != nil {
return nil, nil, err
}
conn, err := jumpClient.dialTCPContext(ctx, "tcp", targetAddr, jumpClient.Close)
if err != nil {
_ = jumpClient.Close()
return nil, nil, err
}
return conn, jumpClient, nil
}
func dialViaProxy(ctx context.Context, dialContext DialContextFunc, proxy ProxyConfig, targetAddr string) (net.Conn, *StarSSH, error) {
if dialContext == nil {
return nil, nil, errors.New("dial context is nil")
}
if strings.TrimSpace(proxy.Addr) == "" {
return nil, nil, errors.New("proxy address is empty")
}
switch proxy.Type {
case ProxyTypeSOCKS5:
conn, err := dialSOCKS5(ctx, dialContext, proxy, targetAddr)
return conn, nil, err
case ProxyTypeHTTPConnect:
conn, err := dialHTTPConnect(ctx, dialContext, proxy, targetAddr)
return conn, nil, err
default:
return nil, nil, fmt.Errorf("unsupported proxy type %q", proxy.Type)
}
}
func dialHTTPConnect(ctx context.Context, dialContext DialContextFunc, proxy ProxyConfig, targetAddr string) (net.Conn, error) {
conn, err := dialContext(ctx, "tcp", proxy.Addr)
if err != nil {
return nil, err
}
restoreDeadline := applyConnDeadline(conn, ctx, proxy.Timeout)
defer restoreDeadline()
request := fmt.Sprintf("CONNECT %s HTTP/1.1\r\nHost: %s\r\n", targetAddr, targetAddr)
if proxy.Username != "" || proxy.Password != "" {
token := base64.StdEncoding.EncodeToString([]byte(proxy.Username + ":" + proxy.Password))
request += "Proxy-Authorization: Basic " + token + "\r\n"
}
request += "\r\n"
if _, err := io.WriteString(conn, request); err != nil {
_ = conn.Close()
return nil, err
}
reader := bufio.NewReader(conn)
response, err := http.ReadResponse(reader, &http.Request{Method: http.MethodConnect})
if err != nil {
_ = conn.Close()
return nil, err
}
defer response.Body.Close()
if response.StatusCode < 200 || response.StatusCode >= 300 {
_, _ = io.Copy(io.Discard, io.LimitReader(response.Body, 1024))
_ = conn.Close()
return nil, fmt.Errorf("http CONNECT proxy rejected target %s: %s", targetAddr, response.Status)
}
if reader.Buffered() == 0 {
return conn, nil
}
return &bufferedConn{Conn: conn, reader: reader}, nil
}
func dialSOCKS5(ctx context.Context, dialContext DialContextFunc, proxy ProxyConfig, targetAddr string) (net.Conn, error) {
conn, err := dialContext(ctx, "tcp", proxy.Addr)
if err != nil {
return nil, err
}
restoreDeadline := applyConnDeadline(conn, ctx, proxy.Timeout)
defer restoreDeadline()
methods := []byte{0x00}
useAuth := proxy.Username != "" || proxy.Password != ""
if useAuth {
methods = append(methods, 0x02)
}
hello := append([]byte{0x05, byte(len(methods))}, methods...)
if _, err := conn.Write(hello); err != nil {
_ = conn.Close()
return nil, err
}
response := make([]byte, 2)
if _, err := io.ReadFull(conn, response); err != nil {
_ = conn.Close()
return nil, err
}
if response[0] != 0x05 {
_ = conn.Close()
return nil, fmt.Errorf("invalid socks5 version %d", response[0])
}
if response[1] == 0xFF {
_ = conn.Close()
return nil, errors.New("socks5 proxy has no acceptable auth method")
}
if response[1] == 0x02 {
if err := writeSOCKS5UserPassAuth(conn, proxy.Username, proxy.Password); err != nil {
_ = conn.Close()
return nil, err
}
}
if err := writeSOCKS5Connect(conn, targetAddr); err != nil {
_ = conn.Close()
return nil, err
}
if err := readSOCKS5ConnectResponse(conn); err != nil {
_ = conn.Close()
return nil, err
}
return conn, nil
}
func writeSOCKS5UserPassAuth(conn net.Conn, username string, password string) error {
if len(username) > 255 || len(password) > 255 {
return errors.New("socks5 username/password too long")
}
request := make([]byte, 0, 3+len(username)+len(password))
request = append(request, 0x01, byte(len(username)))
request = append(request, []byte(username)...)
request = append(request, byte(len(password)))
request = append(request, []byte(password)...)
if _, err := conn.Write(request); err != nil {
return err
}
response := make([]byte, 2)
if _, err := io.ReadFull(conn, response); err != nil {
return err
}
if response[1] != 0x00 {
return errors.New("socks5 username/password authentication failed")
}
return nil
}
func writeSOCKS5Connect(conn net.Conn, targetAddr string) error {
host, portString, err := net.SplitHostPort(targetAddr)
if err != nil {
return err
}
port, err := strconv.Atoi(portString)
if err != nil {
return err
}
if port < 0 || port > 65535 {
return fmt.Errorf("invalid port %d", port)
}
request := []byte{0x05, 0x01, 0x00}
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
request = append(request, 0x01)
request = append(request, ip4...)
} else {
request = append(request, 0x04)
request = append(request, ip.To16()...)
}
} else {
if len(host) > 255 {
return errors.New("socks5 target host too long")
}
request = append(request, 0x03, byte(len(host)))
request = append(request, []byte(host)...)
}
request = append(request, byte(port>>8), byte(port))
_, err = conn.Write(request)
return err
}
func readSOCKS5ConnectResponse(conn net.Conn) error {
header := make([]byte, 4)
if _, err := io.ReadFull(conn, header); err != nil {
return err
}
if header[0] != 0x05 {
return fmt.Errorf("invalid socks5 response version %d", header[0])
}
if header[1] != 0x00 {
return fmt.Errorf("socks5 connect failed with code %d", header[1])
}
switch header[3] {
case 0x01:
if _, err := io.ReadFull(conn, make([]byte, 4)); err != nil {
return err
}
case 0x03:
size := make([]byte, 1)
if _, err := io.ReadFull(conn, size); err != nil {
return err
}
if _, err := io.ReadFull(conn, make([]byte, int(size[0]))); err != nil {
return err
}
case 0x04:
if _, err := io.ReadFull(conn, make([]byte, 16)); err != nil {
return err
}
default:
return fmt.Errorf("unsupported socks5 bind address type %d", header[3])
}
_, err := io.ReadFull(conn, make([]byte, 2))
return err
}
func applyConnDeadline(conn net.Conn, ctx context.Context, timeout time.Duration) func() {
if conn == nil {
return func() {}
}
var (
deadline time.Time
hasValue bool
)
if ctx != nil {
if ctxDeadline, ok := ctx.Deadline(); ok {
deadline = ctxDeadline
hasValue = true
}
}
if timeout > 0 {
timeoutDeadline := time.Now().Add(timeout)
if !hasValue || timeoutDeadline.Before(deadline) {
deadline = timeoutDeadline
hasValue = true
}
}
if !hasValue {
return func() {}
}
_ = conn.SetDeadline(deadline)
return func() {
_ = conn.SetDeadline(time.Time{})
}
}
func normalizeProxyConfig(proxy *ProxyConfig, defaultTimeout time.Duration) *ProxyConfig {
if proxy == nil {
return nil
}
normalized := *proxy
normalized.Type = ProxyType(strings.ToLower(strings.TrimSpace(string(normalized.Type))))
normalized.Addr = strings.TrimSpace(normalized.Addr)
if normalized.Timeout <= 0 {
normalized.Timeout = defaultTimeout
}
return &normalized
}
func joinHostPort(host string, port int) string {
return net.JoinHostPort(strings.TrimSpace(host), strconv.Itoa(port))
}
+237
View File
@@ -0,0 +1,237 @@
package starssh
import (
"bufio"
"context"
"io"
"net"
"sync"
"sync/atomic"
"time"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
)
const (
defaultSSHPort = 22
defaultLoginTimeout = 5 * time.Second
defaultSSHAgentTimeout = 2 * time.Minute
defaultKeepAliveTimeout = 3 * time.Second
defaultShellPollInterval = 120 * time.Millisecond
defaultShellSetupDelay = 200 * time.Millisecond
defaultShellSetupTimeout = 3 * time.Second
defaultShellWaitTimeout = 30 * time.Second
defaultShellPromptToken = "__STARSSH_PROMPT__>"
defaultPTYTerm = "xterm"
defaultPTYRows = 500
defaultPTYColumns = 250
defaultTransferBufferSize = 1024 * 1024
defaultExecStreamMaxPendingChunks = 256
defaultExecStreamMaxPendingBytes = 4 * 1024 * 1024
)
type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error)
type ProxyType string
const (
ProxyTypeSOCKS5 ProxyType = "socks5"
ProxyTypeHTTPConnect ProxyType = "http_connect"
)
type ProxyConfig struct {
Type ProxyType
Addr string
Username string
Password string
Timeout time.Duration
}
type AuthMethodKind string
const (
AuthMethodPrivateKey AuthMethodKind = "private_key"
AuthMethodPassword AuthMethodKind = "password"
AuthMethodKeyboardInteractive AuthMethodKind = "keyboard_interactive"
AuthMethodSSHAgent AuthMethodKind = "ssh_agent"
)
type SSHAgentDebugFunc func(SSHAgentDebugEvent)
type SSHAgentDebugEvent struct {
Step string
Source string
Endpoint string
Network string
Phase string
Status string
Duration time.Duration
KeyCount int
Err error
}
type StarSSH struct {
stateMu sync.RWMutex
Client *ssh.Client
PublicKey ssh.PublicKey
PubkeyBase64 string
Hostname string
RemoteAddr net.Addr
Banner string
LoginInfo LoginInput
online bool
upstream *StarSSH
sftpClient *sftp.Client
sftpMu sync.Mutex
agentForwardMu sync.Mutex
agentForwarder io.Closer
keepaliveMu sync.Mutex
keepaliveStop chan struct{}
keepaliveDone chan struct{}
closing atomic.Bool
}
type LoginInput struct {
KeyExchanges []string
Ciphers []string
MACs []string
User string
Password string
PasswordCallback func() (string, error)
KeyboardInteractiveCallback ssh.KeyboardInteractiveChallenge
Prikey string
PrikeyPwd string
DisableSSHAgent bool
ForwardSSHAgent bool
AuthOrder []AuthMethodKind
// IdentityAgent overrides the local ssh-agent endpoint used for authentication
// and agent forwarding. Empty uses SSH_AUTH_SOCK, or the platform default where
// one exists.
IdentityAgent string
Addr string
Port int
// Timeout limits the SSH handshake/authentication phase after a TCP connection has
// already been established. Zero means no authentication timeout.
Timeout time.Duration
// DialTimeout limits outbound dial steps such as TCP connect, proxy connect, and
// local ssh-agent socket connect. Zero falls back to Timeout when set, otherwise
// uses the package default dial timeout. Negative disables the default dial timeout.
DialTimeout time.Duration
// SSHAgentTimeout limits ssh-agent protocol operations such as listing keys and
// signing challenges. Zero uses the package default, and negative disables the
// per-operation deadline. This is intentionally separate from Timeout and
// DialTimeout because hardware-backed agents may require a PIN or touch confirmation.
SSHAgentTimeout time.Duration
// SSHAgentForwardTimeout limits idle reads and writes on forwarded agent
// channels. Zero or negative leaves forwarded channels without an idle deadline.
SSHAgentForwardTimeout time.Duration
// SSHAgentDebug receives structured ssh-agent dial/protocol events. It is nil by
// default and must not log private key material.
SSHAgentDebug SSHAgentDebugFunc
DialContext DialContextFunc
Proxy *ProxyConfig
Jump *LoginInput
KeepAliveInterval time.Duration
KeepAliveTimeout time.Duration
HostKeyCallback func(string, net.Addr, ssh.PublicKey) error
BannerCallback func(string) error
}
// StarShell keeps the legacy prompt-driven helper for POSIX-style scripted shell interactions.
// It is not a generic cross-shell abstraction; for product-grade interactive terminals, prefer TerminalSession.
type StarShell struct {
Keyword string
UseWaitDefault bool
WaitTimeout time.Duration
Session *ssh.Session
in io.Writer
out *bufio.Reader
er *bufio.Reader
outbyte []byte
errbyte []byte
lastout int64
errors error
isprint bool
isfuncs bool
iscolor bool
isecho bool
rw sync.RWMutex
funcs func(string)
writeMu sync.Mutex
commandMu sync.Mutex
closeOnce sync.Once
promptToken string
}
type TerminalConfig struct {
Term string
Rows int
Columns int
Modes ssh.TerminalModes
}
type TerminalControl byte
const (
TerminalControlInterrupt TerminalControl = 0x03
TerminalControlEOF TerminalControl = 0x04
TerminalControlBell TerminalControl = 0x07
TerminalControlBackspace TerminalControl = 0x08
TerminalControlLineKill TerminalControl = 0x15
TerminalControlQuit TerminalControl = 0x1c
TerminalControlSuspend TerminalControl = 0x1a
TerminalControlPauseOutput TerminalControl = 0x13
TerminalControlResumeOutput TerminalControl = 0x11
)
type TerminalCloseReason string
const (
TerminalCloseReasonUnknown TerminalCloseReason = ""
TerminalCloseReasonExit TerminalCloseReason = "exit"
TerminalCloseReasonSignal TerminalCloseReason = "signal"
TerminalCloseReasonClosed TerminalCloseReason = "closed"
TerminalCloseReasonContextCanceled TerminalCloseReason = "context_canceled"
TerminalCloseReasonDeadlineExceeded TerminalCloseReason = "deadline_exceeded"
TerminalCloseReasonTransportError TerminalCloseReason = "transport_error"
)
type TerminalExitInfo struct {
ExitCode int
ExitSignal string
ExitMessage string
Reason TerminalCloseReason
}
type TerminalSession struct {
Session *ssh.Session
ID string
Label string
Metadata map[string]string
stdin io.WriteCloser
stdout io.Reader
stderr io.Reader
attachMu sync.RWMutex
in io.Reader
out io.Writer
errOut io.Writer
runOnce sync.Once
runDone chan struct{}
runErr error
waitOnce sync.Once
waitDone chan struct{}
stateMu sync.RWMutex
waitErr error
exitInfo TerminalExitInfo
closeReason TerminalCloseReason
closeErr error
closeOnce sync.Once
}
+162
View File
@@ -0,0 +1,162 @@
package starssh
import (
"crypto/rand"
"encoding/hex"
"fmt"
"regexp"
"strconv"
"strings"
"time"
)
var (
ansiCSIRegexp = regexp.MustCompile(`\x1b\[[0-9;?]*[ -/]*[@-~]`)
ansiOSCRegexp = regexp.MustCompile(`\x1b\][^\x07]*(\x07|\x1b\\)`)
leadingIntRegexp = regexp.MustCompile(`^[+-]?\d+`)
)
func SedColor(str string) string {
return stripControlSequences(str)
}
func normalizeShellOutput(raw string) string {
return strings.TrimSpace(strings.ReplaceAll(raw, "\r\n", "\n"))
}
func stripControlSequences(raw string) string {
cleaned := ansiOSCRegexp.ReplaceAllString(raw, "")
cleaned = ansiCSIRegexp.ReplaceAllString(cleaned, "")
cleaned = strings.ReplaceAll(cleaned, "\r", "")
cleaned = strings.Map(func(r rune) rune {
if r == '\n' || r == '\t' {
return r
}
if r < 0x20 || r == 0x7f {
return -1
}
return r
}, cleaned)
return cleaned
}
func stripLeadingEcho(output string, command string, markerCommand string) string {
result := output
if command == "" {
return result
}
withMarker := command
if markerCommand != "" {
withMarker += "\n" + markerCommand
}
if strings.HasPrefix(result, withMarker) {
return strings.TrimSpace(strings.TrimPrefix(result, withMarker))
}
if strings.HasPrefix(result, command) {
return strings.TrimSpace(strings.TrimPrefix(result, command))
}
return result
}
func collectLinesWithoutTokens(output string, tokens ...string) string {
lines := strings.Split(output, "\n")
filtered := make([]string, 0, len(lines))
for _, line := range lines {
skip := false
for _, token := range tokens {
if token != "" && strings.Contains(line, token) {
skip = true
break
}
}
if !skip {
filtered = append(filtered, line)
}
}
return strings.TrimSpace(strings.Join(filtered, "\n"))
}
func collectLinesForCommandOutput(output string, promptToken string, tokens ...string) string {
lines := strings.Split(output, "\n")
filtered := make([]string, 0, len(lines))
for _, line := range lines {
skip := false
for _, token := range tokens {
if token != "" && strings.Contains(line, token) {
skip = true
break
}
}
if !skip && promptToken != "" && strings.Contains(line, promptToken) {
skip = true
}
if !skip {
filtered = append(filtered, line)
}
}
return strings.TrimSpace(strings.Join(filtered, "\n"))
}
func shellSingleQuote(s string) string {
return "'" + strings.ReplaceAll(s, "'", "'\"'\"'") + "'"
}
func normalizeBufferSize(bufcap int) int {
if bufcap <= 0 {
return defaultTransferBufferSize
}
return bufcap
}
func newCommandTokens() (beginToken string, endToken string) {
nonce := newNonce(8)
return "__STARSSH_BEGIN_" + nonce + "__", "__STARSSH_END_" + nonce + "__"
}
func newNonce(size int) string {
if size <= 0 {
size = 8
}
buf := make([]byte, size)
if _, err := rand.Read(buf); err != nil {
return fmt.Sprintf("%d", time.Now().UnixNano())
}
return strings.ToUpper(hex.EncodeToString(buf))
}
func splitByEndToken(output string, endToken string) (before string, exitCode int, found bool, parseErr error) {
prefix := endToken + ":"
lines := strings.Split(output, "\n")
beforeLines := make([]string, 0, len(lines))
for _, line := range lines {
trimmedLine := strings.TrimSpace(line)
if !strings.HasPrefix(trimmedLine, prefix) {
beforeLines = append(beforeLines, line)
continue
}
codeText := strings.TrimSpace(trimmedLine[len(prefix):])
match := leadingIntRegexp.FindString(codeText)
if match == "" || match != codeText {
beforeLines = append(beforeLines, line)
continue
}
code, err := strconv.Atoi(match)
if err != nil {
beforeLines = append(beforeLines, line)
continue
}
return strings.Join(beforeLines, "\n"), code, true, nil
}
return strings.Join(beforeLines, "\n"), 0, false, nil
}