refactor: 重构 starssh 核心运行时并补强 ssh/exec/terminal/sftp 能力

- 拆分原有单体 ssh.go,按职责重组为 types、utils、transport、login、keepalive、session、exec、pool、shell、terminal、forward、hostkey、state 等模块,并补充平台相关实现
  - 重做登录与连接运行时,补齐基于 context 的建连、jump/proxy 链路、可配置认证顺序,以及 Unix/Windows 下的 ssh-agent 支持
  - 新增正式非交互执行模型 ExecRequest/ExecResult,支持流式输出、溢出统计、超时控制,以及 posix/powershell/cmd/raw 多方言执行
  - 保留旧 shell 风格兼容接口,同时让路径/用户探测等 helper 具备跨 shell fallback,避免 Windows 目标继续硬依赖 POSIX 命令
  - 新增 TerminalSession 作为原始交互终端基座,提供 IO attach、resize、signal/control、退出状态与关闭原因管理
  - 重构端口转发语义,默认复用当前 SSH 连接,并显式提供 detached 的本地/动态转发模式承载隔离场景
  - 梳理 keepalive 与取消语义,区分仅取消本次操作和关闭整条连接,并统一连接状态与传输关闭路径
  - 围绕新的 session/连接生命周期重做执行池与运行时支撑
  - 大幅增强 SFTP 传输链路,补齐更安全的原子替换、校验、进度回调、重试隔离、可复用 client 生命周期与失败语义
  - 新增取消语义、keepalive、SFTP、forward、terminal input 等关键回归测试,提升核心链路稳定性
This commit is contained in:
兔子 2026-04-26 10:45:39 +08:00
parent d6fbea8468
commit f20eb653ae
Signed by: b612
GPG Key ID: 99DD2222B612B612
31 changed files with 8538 additions and 769 deletions

9
.gitignore vendored Normal file
View File

@ -0,0 +1,9 @@
.sentrux/
agent_readme.md
target.md
.gocache/
.tmp_*/
.codex/
.idea/
agents.md
.codex

201
LICENSE Normal file
View File

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

172
cancel_semantics_test.go Normal file
View File

@ -0,0 +1,172 @@
package starssh
import (
"context"
"errors"
"net"
"sync/atomic"
"testing"
"time"
"golang.org/x/crypto/ssh"
)
func TestPingContextDoesNotCloseConnectionOnCancel(t *testing.T) {
oldSendKeepAliveRequest := sendKeepAliveRequest
oldCloseSSHClient := closeSSHClient
t.Cleanup(func() {
sendKeepAliveRequest = oldSendKeepAliveRequest
closeSSHClient = oldCloseSSHClient
})
sendKeepAliveRequest = func(ctx context.Context, client sshClientRequester) error {
<-ctx.Done()
time.Sleep(20 * time.Millisecond)
return ctx.Err()
}
var closeCalls atomic.Int32
closeSSHClient = func(client sshClientRequester) error {
closeCalls.Add(1)
return nil
}
client := &ssh.Client{}
star := &StarSSH{}
star.setTransport(client, nil)
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
err := star.PingContext(ctx)
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("expected deadline exceeded, got %v", err)
}
if closeCalls.Load() != 0 {
t.Fatalf("expected PingContext to keep connection open, close calls=%d", closeCalls.Load())
}
if got := star.snapshotSSHClient(); got != client {
t.Fatal("expected ssh client to remain attached after PingContext cancel")
}
}
func TestPingContextCloseOnCancelClosesConnection(t *testing.T) {
oldSendKeepAliveRequest := sendKeepAliveRequest
oldCloseSSHClient := closeSSHClient
t.Cleanup(func() {
sendKeepAliveRequest = oldSendKeepAliveRequest
closeSSHClient = oldCloseSSHClient
})
sendKeepAliveRequest = func(ctx context.Context, client sshClientRequester) error {
<-ctx.Done()
time.Sleep(20 * time.Millisecond)
return ctx.Err()
}
var closeCalls atomic.Int32
closeSSHClient = func(client sshClientRequester) error {
closeCalls.Add(1)
return nil
}
star := &StarSSH{}
star.setTransport(&ssh.Client{}, nil)
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
err := star.PingContextCloseOnCancel(ctx)
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("expected deadline exceeded, got %v", err)
}
if closeCalls.Load() != 1 {
t.Fatalf("expected exactly one close call, got %d", closeCalls.Load())
}
if got := star.snapshotSSHClient(); got != nil {
t.Fatal("expected ssh client to be detached after PingContextCloseOnCancel")
}
}
func TestDialTCPContextDoesNotCloseConnectionOnCancel(t *testing.T) {
oldDialSSHClient := dialSSHClient
oldCloseSSHClient := closeSSHClient
t.Cleanup(func() {
dialSSHClient = oldDialSSHClient
closeSSHClient = oldCloseSSHClient
})
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
<-ctx.Done()
time.Sleep(20 * time.Millisecond)
return nil, ctx.Err()
}
var closeCalls atomic.Int32
closeSSHClient = func(client sshClientRequester) error {
closeCalls.Add(1)
return nil
}
client := &ssh.Client{}
star := &StarSSH{}
star.setTransport(client, nil)
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
conn, err := star.DialTCPContext(ctx, "tcp", "127.0.0.1:22")
if conn != nil {
t.Fatal("expected nil connection on canceled dial")
}
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("expected deadline exceeded, got %v", err)
}
if closeCalls.Load() != 0 {
t.Fatalf("expected DialTCPContext to keep connection open, close calls=%d", closeCalls.Load())
}
if got := star.snapshotSSHClient(); got != client {
t.Fatal("expected ssh client to remain attached after DialTCPContext cancel")
}
}
func TestDialTCPContextCloseOnCancelClosesConnection(t *testing.T) {
oldDialSSHClient := dialSSHClient
oldCloseSSHClient := closeSSHClient
t.Cleanup(func() {
dialSSHClient = oldDialSSHClient
closeSSHClient = oldCloseSSHClient
})
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
<-ctx.Done()
time.Sleep(20 * time.Millisecond)
return nil, ctx.Err()
}
var closeCalls atomic.Int32
closeSSHClient = func(client sshClientRequester) error {
closeCalls.Add(1)
return nil
}
star := &StarSSH{}
star.setTransport(&ssh.Client{}, nil)
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
conn, err := star.DialTCPContextCloseOnCancel(ctx, "tcp", "127.0.0.1:22")
if conn != nil {
t.Fatal("expected nil connection on canceled dial")
}
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("expected deadline exceeded, got %v", err)
}
if closeCalls.Load() != 1 {
t.Fatalf("expected exactly one close call, got %d", closeCalls.Load())
}
if got := star.snapshotSSHClient(); got != nil {
t.Fatal("expected ssh client to be detached after DialTCPContextCloseOnCancel")
}
}

1218
exec.go Normal file

File diff suppressed because it is too large Load Diff

69
exec_legacy_test.go Normal file
View File

@ -0,0 +1,69 @@
package starssh
import (
"context"
"strings"
"testing"
)
func TestExistsFallsBackFromPOSIXToPowerShell(t *testing.T) {
oldExecRequestRunner := execRequestRunner
t.Cleanup(func() {
execRequestRunner = oldExecRequestRunner
})
var calls []ExecRequest
execRequestRunner = func(s *StarSSH, ctx context.Context, req ExecRequest) (*ExecResult, error) {
calls = append(calls, req)
switch len(calls) {
case 1:
return &ExecResult{ExitCode: 127, Stderr: []byte("sh not found")}, nil
case 2:
return &ExecResult{Stdout: []byte("1\n")}, nil
default:
t.Fatalf("unexpected extra probe request: %+v", req)
return nil, nil
}
}
star := &StarSSH{}
if !star.Exists(`C:\Windows\System32`) {
t.Fatal("expected helper to succeed after PowerShell fallback")
}
if len(calls) != 2 {
t.Fatalf("expected two probe attempts, got %d", len(calls))
}
if calls[0].ShellDialect != ExecShellDialectRaw || !strings.HasPrefix(calls[0].Command, "sh -lc ") {
t.Fatalf("unexpected first probe request: %+v", calls[0])
}
if calls[1].ShellDialect != ExecShellDialectRaw || !strings.HasPrefix(strings.ToLower(calls[1].Command), "powershell.exe ") {
t.Fatalf("unexpected second probe request: %+v", calls[1])
}
}
func TestGetUserFallsBackToCMDWhenPowerShellVariantsFail(t *testing.T) {
oldExecRequestRunner := execRequestRunner
t.Cleanup(func() {
execRequestRunner = oldExecRequestRunner
})
var calls []ExecRequest
execRequestRunner = func(s *StarSSH, ctx context.Context, req ExecRequest) (*ExecResult, error) {
calls = append(calls, req)
if len(calls) < 5 {
return &ExecResult{ExitCode: 127, Stderr: []byte("command not found")}, nil
}
return &ExecResult{Stdout: []byte("HOST\\tester\r\n")}, nil
}
star := &StarSSH{}
if got := star.GetUser(); got != `HOST\tester` {
t.Fatalf("unexpected user after fallback: %q", got)
}
if len(calls) != 5 {
t.Fatalf("expected five probe attempts, got %d", len(calls))
}
if got := strings.ToLower(calls[4].Command); got != "cmd.exe /q /d /c whoami" {
t.Fatalf("unexpected final fallback command: %q", calls[4].Command)
}
}

694
forward.go Normal file
View File

@ -0,0 +1,694 @@
package starssh
import (
"context"
"errors"
"io"
"net"
"strconv"
"strings"
"sync"
"golang.org/x/crypto/ssh"
)
type ForwardRequest struct {
ListenAddr string
TargetAddr string
DialContext DialContextFunc
}
type DynamicForwardRequest struct {
ListenAddr string
}
type PortForwarder struct {
listener net.Listener
ctx context.Context
cancel context.CancelFunc
acceptDone chan struct{}
connWG sync.WaitGroup
closeOnce sync.Once
connMu sync.Mutex
conns map[net.Conn]struct{}
errMu sync.Mutex
err error
cleanupOnce sync.Once
cleanupFns []func() error
}
var dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
return client.Dial(network, address)
}
var newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) {
if ctx == nil {
ctx = context.Background()
}
return LoginContext(ctx, input)
}
func (s *StarSSH) DialTCP(network string, address string) (net.Conn, error) {
return s.DialTCPContext(context.Background(), network, address)
}
func (s *StarSSH) DialTCPContext(ctx context.Context, network string, address string) (net.Conn, error) {
return s.dialTCPContext(ctx, network, address, nil)
}
func (s *StarSSH) DialTCPContextCloseOnCancel(ctx context.Context, network string, address string) (net.Conn, error) {
return s.dialTCPContext(ctx, network, address, s.Close)
}
func (s *StarSSH) dialTCPContext(ctx context.Context, network string, address string, onCancel func() error) (net.Conn, error) {
if ctx == nil {
ctx = context.Background()
}
if strings.TrimSpace(network) == "" {
network = "tcp"
}
if strings.TrimSpace(address) == "" {
return nil, errors.New("forward address is empty")
}
type dialResult struct {
conn net.Conn
err error
}
client, err := s.requireSSHClient()
if err != nil {
return nil, err
}
runCancel := func() {}
if onCancel != nil {
var cancelOnce sync.Once
runCancel = func() {
cancelOnce.Do(func() {
_ = onCancel()
})
}
cancelDone := make(chan struct{})
defer close(cancelDone)
go func() {
select {
case <-ctx.Done():
runCancel()
case <-cancelDone:
}
}()
}
dialFunc := dialSSHClient
resultCh := make(chan dialResult, 1)
go func() {
conn, err := dialFunc(ctx, client, network, address)
if ctx.Err() != nil && conn != nil {
_ = conn.Close()
conn = nil
}
select {
case resultCh <- dialResult{conn: conn, err: err}:
default:
if conn != nil {
_ = conn.Close()
}
}
}()
select {
case result := <-resultCh:
return result.conn, result.err
case <-ctx.Done():
runCancel()
return nil, ctx.Err()
}
}
func (s *StarSSH) StartLocalForward(req ForwardRequest) (*PortForwarder, error) {
if _, err := s.requireSSHClient(); err != nil {
return nil, err
}
if strings.TrimSpace(req.ListenAddr) == "" {
return nil, errors.New("local forward listen address is empty")
}
if strings.TrimSpace(req.TargetAddr) == "" {
return nil, errors.New("local forward target address is empty")
}
listener, err := net.Listen("tcp", req.ListenAddr)
if err != nil {
return nil, err
}
forwarder := newPortForwarder(listener)
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
return s.DialTCPContext(ctx, "tcp", req.TargetAddr)
})
return forwarder, nil
}
func (s *StarSSH) StartLocalForwardDetached(req ForwardRequest) (*PortForwarder, error) {
if _, err := s.requireSSHClient(); err != nil {
return nil, err
}
if strings.TrimSpace(req.ListenAddr) == "" {
return nil, errors.New("local forward listen address is empty")
}
if strings.TrimSpace(req.TargetAddr) == "" {
return nil, errors.New("local forward target address is empty")
}
listener, err := net.Listen("tcp", req.ListenAddr)
if err != nil {
return nil, err
}
forwardClient, err := s.newForwardDialClient(context.Background())
if err != nil {
_ = listener.Close()
return nil, err
}
forwarder := newPortForwarder(listener)
forwarder.addCleanup(func() error {
return normalizeAlreadyClosedError(forwardClient.Close())
})
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
return forwardClient.DialTCPContext(ctx, "tcp", req.TargetAddr)
})
return forwarder, nil
}
func (s *StarSSH) StartRemoteForward(req ForwardRequest) (*PortForwarder, error) {
client, err := s.requireSSHClient()
if err != nil {
return nil, err
}
if strings.TrimSpace(req.ListenAddr) == "" {
return nil, errors.New("remote forward listen address is empty")
}
if strings.TrimSpace(req.TargetAddr) == "" {
return nil, errors.New("remote forward target address is empty")
}
listener, err := client.Listen("tcp", req.ListenAddr)
if err != nil {
return nil, err
}
dialContext := req.DialContext
if dialContext == nil {
dialer := &net.Dialer{
Timeout: defaultLoginTimeout,
}
dialContext = dialer.DialContext
}
forwarder := newPortForwarder(listener)
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
return dialContext(ctx, "tcp", req.TargetAddr)
})
return forwarder, nil
}
func (s *StarSSH) StartDynamicForward(req DynamicForwardRequest) (*PortForwarder, error) {
if _, err := s.requireSSHClient(); err != nil {
return nil, err
}
if strings.TrimSpace(req.ListenAddr) == "" {
return nil, errors.New("dynamic forward listen address is empty")
}
listener, err := net.Listen("tcp", req.ListenAddr)
if err != nil {
return nil, err
}
forwarder := newPortForwarder(listener)
forwarder.serveDynamic(func(ctx context.Context, targetAddr string) (net.Conn, error) {
return s.DialTCPContext(ctx, "tcp", targetAddr)
})
return forwarder, nil
}
func (s *StarSSH) StartDynamicForwardDetached(req DynamicForwardRequest) (*PortForwarder, error) {
if _, err := s.requireSSHClient(); err != nil {
return nil, err
}
if strings.TrimSpace(req.ListenAddr) == "" {
return nil, errors.New("dynamic forward listen address is empty")
}
listener, err := net.Listen("tcp", req.ListenAddr)
if err != nil {
return nil, err
}
forwardClient, err := s.newForwardDialClient(context.Background())
if err != nil {
_ = listener.Close()
return nil, err
}
forwarder := newPortForwarder(listener)
forwarder.addCleanup(func() error {
return normalizeAlreadyClosedError(forwardClient.Close())
})
forwarder.serveDynamic(func(ctx context.Context, targetAddr string) (net.Conn, error) {
return forwardClient.DialTCPContext(ctx, "tcp", targetAddr)
})
return forwarder, nil
}
func (f *PortForwarder) Addr() net.Addr {
if f == nil || f.listener == nil {
return nil
}
return f.listener.Addr()
}
func (f *PortForwarder) Wait() error {
if f == nil {
return nil
}
<-f.acceptDone
f.connWG.Wait()
f.runCleanup()
return f.Err()
}
func (f *PortForwarder) Err() error {
if f == nil {
return nil
}
f.errMu.Lock()
defer f.errMu.Unlock()
return f.err
}
func (f *PortForwarder) Close() error {
if f == nil {
return nil
}
var closeErr error
f.closeOnce.Do(func() {
if f.cancel != nil {
f.cancel()
}
if f.listener != nil {
closeErr = normalizeAlreadyClosedError(f.listener.Close())
}
f.closeActiveConnections()
})
<-f.acceptDone
f.connWG.Wait()
f.runCleanup()
if closeErr != nil {
return closeErr
}
return f.Err()
}
func newPortForwarder(listener net.Listener) *PortForwarder {
ctx, cancel := context.WithCancel(context.Background())
return &PortForwarder{
listener: listener,
ctx: ctx,
cancel: cancel,
acceptDone: make(chan struct{}),
conns: make(map[net.Conn]struct{}),
}
}
func (s *StarSSH) newForwardDialClient(ctx context.Context) (*StarSSH, error) {
if s == nil {
return nil, errors.New("ssh client is nil")
}
return newDetachedForwardClient(ctx, s.LoginInfo)
}
func (f *PortForwarder) addCleanup(fn func() error) {
if f == nil || fn == nil {
return
}
f.cleanupFns = append(f.cleanupFns, fn)
}
func (f *PortForwarder) runCleanup() {
if f == nil {
return
}
f.cleanupOnce.Do(func() {
for _, fn := range f.cleanupFns {
f.setError(normalizeAlreadyClosedError(fn()))
}
})
}
func (f *PortForwarder) serve(targetDial func(context.Context) (net.Conn, error)) {
go func() {
defer close(f.acceptDone)
for {
conn, err := f.listener.Accept()
if err != nil {
if isClosedListenerError(err) {
return
}
f.setError(err)
return
}
f.trackConn(conn)
f.connWG.Add(1)
go func(src net.Conn) {
defer f.connWG.Done()
defer f.untrackConn(src)
dst, err := targetDial(f.ctx)
if err != nil {
f.setError(err)
_ = src.Close()
return
}
f.trackConn(dst)
defer f.untrackConn(dst)
f.setError(pipeForwardConnections(src, dst))
}(conn)
}
}()
}
func (f *PortForwarder) serveDynamic(targetDial func(context.Context, string) (net.Conn, error)) {
go func() {
defer close(f.acceptDone)
for {
conn, err := f.listener.Accept()
if err != nil {
if isClosedListenerError(err) {
return
}
f.setError(err)
return
}
f.trackConn(conn)
f.connWG.Add(1)
go func(src net.Conn) {
defer f.connWG.Done()
defer f.untrackConn(src)
if err := handleDynamicForwardConn(f.ctx, src, targetDial, f.trackConn, f.untrackConn); err != nil {
f.setError(err)
}
}(conn)
}
}()
}
func (f *PortForwarder) setError(err error) {
if f == nil || err == nil || f.shouldIgnoreError(err) {
return
}
f.errMu.Lock()
defer f.errMu.Unlock()
if f.err == nil {
f.err = err
}
}
func pipeForwardConnections(left net.Conn, right net.Conn) error {
if left == nil || right == nil {
if left != nil {
_ = left.Close()
}
if right != nil {
_ = right.Close()
}
return errors.New("forward connection endpoint is nil")
}
defer left.Close()
defer right.Close()
var copyWG sync.WaitGroup
errCh := make(chan error, 2)
copyWG.Add(2)
go func() {
defer copyWG.Done()
_, err := io.Copy(right, left)
errCh <- normalizeAlreadyClosedError(err)
closeWrite(right)
}()
go func() {
defer copyWG.Done()
_, err := io.Copy(left, right)
errCh <- normalizeAlreadyClosedError(err)
closeWrite(left)
}()
copyWG.Wait()
close(errCh)
for err := range errCh {
if err != nil {
return err
}
}
return nil
}
func closeWrite(conn net.Conn) {
if conn == nil {
return
}
type closeWriter interface {
CloseWrite() error
}
if writer, ok := conn.(closeWriter); ok {
_ = writer.CloseWrite()
}
}
func handleDynamicForwardConn(
ctx context.Context,
src net.Conn,
targetDial func(context.Context, string) (net.Conn, error),
trackConn func(net.Conn),
untrackConn func(net.Conn),
) error {
if src == nil {
return errors.New("dynamic forward source connection is nil")
}
defer src.Close()
if err := negotiateSOCKS5NoAuth(src); err != nil {
return err
}
targetAddr, replyCode, err := readSOCKS5ConnectTarget(src)
if err != nil {
_ = writeSOCKS5ServerReply(src, replyCode, nil)
return err
}
dst, err := targetDial(ctx, targetAddr)
if err != nil {
_ = writeSOCKS5ServerReply(src, 0x01, nil)
return err
}
if trackConn != nil {
trackConn(dst)
defer untrackConn(dst)
}
if err := writeSOCKS5ServerReply(src, 0x00, dst.LocalAddr()); err != nil {
_ = dst.Close()
return err
}
return pipeForwardConnections(src, dst)
}
func (f *PortForwarder) trackConn(conn net.Conn) {
if f == nil || conn == nil {
return
}
f.connMu.Lock()
defer f.connMu.Unlock()
f.conns[conn] = struct{}{}
}
func (f *PortForwarder) untrackConn(conn net.Conn) {
if f == nil || conn == nil {
return
}
f.connMu.Lock()
defer f.connMu.Unlock()
delete(f.conns, conn)
}
func (f *PortForwarder) closeActiveConnections() {
if f == nil {
return
}
f.connMu.Lock()
conns := make([]net.Conn, 0, len(f.conns))
for conn := range f.conns {
conns = append(conns, conn)
}
f.connMu.Unlock()
for _, conn := range conns {
_ = conn.Close()
}
}
func (f *PortForwarder) shouldIgnoreError(err error) bool {
if err == nil {
return true
}
if normalizeAlreadyClosedError(err) == nil {
return true
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return true
}
return false
}
func negotiateSOCKS5NoAuth(conn net.Conn) error {
header := make([]byte, 2)
if _, err := io.ReadFull(conn, header); err != nil {
return err
}
if header[0] != 0x05 {
return errors.New("invalid socks5 version")
}
methodCount := int(header[1])
methods := make([]byte, methodCount)
if _, err := io.ReadFull(conn, methods); err != nil {
return err
}
method := byte(0xFF)
for _, candidate := range methods {
if candidate == 0x00 {
method = 0x00
break
}
}
if _, err := conn.Write([]byte{0x05, method}); err != nil {
return err
}
if method == 0xFF {
return errors.New("socks5 client does not support no-auth method")
}
return nil
}
func readSOCKS5ConnectTarget(conn net.Conn) (string, byte, error) {
header := make([]byte, 4)
if _, err := io.ReadFull(conn, header); err != nil {
return "", 0x01, err
}
if header[0] != 0x05 {
return "", 0x01, errors.New("invalid socks5 request version")
}
if header[1] != 0x01 {
return "", 0x07, errors.New("unsupported socks5 command")
}
host, err := readSOCKS5RequestHost(conn, header[3])
if err != nil {
return "", 0x08, err
}
portBytes := make([]byte, 2)
if _, err := io.ReadFull(conn, portBytes); err != nil {
return "", 0x01, err
}
port := int(portBytes[0])<<8 | int(portBytes[1])
return net.JoinHostPort(host, strconv.Itoa(port)), 0x00, nil
}
func readSOCKS5RequestHost(conn net.Conn, addressType byte) (string, error) {
switch addressType {
case 0x01:
buffer := make([]byte, 4)
if _, err := io.ReadFull(conn, buffer); err != nil {
return "", err
}
return net.IP(buffer).String(), nil
case 0x03:
size := make([]byte, 1)
if _, err := io.ReadFull(conn, size); err != nil {
return "", err
}
buffer := make([]byte, int(size[0]))
if _, err := io.ReadFull(conn, buffer); err != nil {
return "", err
}
return string(buffer), nil
case 0x04:
buffer := make([]byte, 16)
if _, err := io.ReadFull(conn, buffer); err != nil {
return "", err
}
return net.IP(buffer).String(), nil
default:
return "", errors.New("unsupported socks5 address type")
}
}
func writeSOCKS5ServerReply(conn net.Conn, replyCode byte, addr net.Addr) error {
reply := []byte{0x05, replyCode, 0x00}
if tcpAddr, ok := addr.(*net.TCPAddr); ok && tcpAddr != nil {
if ip4 := tcpAddr.IP.To4(); ip4 != nil {
reply = append(reply, 0x01)
reply = append(reply, ip4...)
} else if ip16 := tcpAddr.IP.To16(); ip16 != nil {
reply = append(reply, 0x04)
reply = append(reply, ip16...)
}
if len(reply) > 3 {
port := tcpAddr.Port
reply = append(reply, byte(port>>8), byte(port))
_, err := conn.Write(reply)
return err
}
}
reply = append(reply, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00)
_, err := conn.Write(reply)
return err
}
func isClosedListenerError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, io.EOF) {
return true
}
if errors.Is(err, net.ErrClosed) {
return true
}
return strings.Contains(err.Error(), "use of closed network connection")
}

164
forward_test.go Normal file
View File

@ -0,0 +1,164 @@
package starssh
import (
"context"
"io"
"net"
"sync/atomic"
"testing"
"time"
"golang.org/x/crypto/ssh"
)
func TestStartLocalForwardUsesExistingConnectionByDefault(t *testing.T) {
oldDialSSHClient := dialSSHClient
oldNewDetachedForwardClient := newDetachedForwardClient
oldCloseSSHClient := closeSSHClient
t.Cleanup(func() {
dialSSHClient = oldDialSSHClient
newDetachedForwardClient = oldNewDetachedForwardClient
closeSSHClient = oldCloseSSHClient
})
baseClient := &ssh.Client{}
star := &StarSSH{}
star.setTransport(baseClient, nil)
var detachedCalls atomic.Int32
newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) {
detachedCalls.Add(1)
return nil, nil
}
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
if client != baseClient {
t.Errorf("expected existing ssh client, got %p want %p", client, baseClient)
}
serverConn, clientConn := net.Pipe()
go echoForwardPipe(serverConn)
return clientConn, nil
}
closeSSHClient = func(client sshClientRequester) error {
t.Fatal("default local forward should not close the main ssh client")
return nil
}
forwarder, err := star.StartLocalForward(ForwardRequest{
ListenAddr: "127.0.0.1:0",
TargetAddr: "example.internal:22",
})
if err != nil {
t.Fatalf("start local forward: %v", err)
}
defer forwarder.Close()
reply := exerciseForwarder(t, forwarder.Addr().String(), []byte("ping"))
if string(reply) != "ping" {
t.Fatalf("unexpected forwarded reply: %q", string(reply))
}
if detachedCalls.Load() != 0 {
t.Fatalf("default local forward should not create detached ssh client, calls=%d", detachedCalls.Load())
}
}
func TestStartLocalForwardDetachedUsesSeparateConnection(t *testing.T) {
oldDialSSHClient := dialSSHClient
oldNewDetachedForwardClient := newDetachedForwardClient
oldCloseSSHClient := closeSSHClient
t.Cleanup(func() {
dialSSHClient = oldDialSSHClient
newDetachedForwardClient = oldNewDetachedForwardClient
closeSSHClient = oldCloseSSHClient
})
baseClient := &ssh.Client{}
detachedClient := &ssh.Client{}
star := &StarSSH{LoginInfo: LoginInput{User: "tester", Addr: "127.0.0.1"}}
star.setTransport(baseClient, nil)
forwardClient := &StarSSH{}
forwardClient.setTransport(detachedClient, nil)
var detachedCalls atomic.Int32
newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) {
detachedCalls.Add(1)
return forwardClient, nil
}
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
if client != detachedClient {
t.Errorf("expected detached ssh client, got %p want %p", client, detachedClient)
}
serverConn, clientConn := net.Pipe()
go echoForwardPipe(serverConn)
return clientConn, nil
}
var closeCalls atomic.Int32
closeSSHClient = func(client sshClientRequester) error {
closeCalls.Add(1)
return nil
}
forwarder, err := star.StartLocalForwardDetached(ForwardRequest{
ListenAddr: "127.0.0.1:0",
TargetAddr: "example.internal:22",
})
if err != nil {
t.Fatalf("start detached local forward: %v", err)
}
reply := exerciseForwarder(t, forwarder.Addr().String(), []byte("pong"))
if string(reply) != "pong" {
t.Fatalf("unexpected detached forwarded reply: %q", string(reply))
}
if err := forwarder.Close(); err != nil {
t.Fatalf("close detached local forward: %v", err)
}
if detachedCalls.Load() != 1 {
t.Fatalf("expected one detached ssh login, got %d", detachedCalls.Load())
}
if closeCalls.Load() != 1 {
t.Fatalf("expected detached ssh client cleanup once, got %d", closeCalls.Load())
}
if got := star.snapshotSSHClient(); got != baseClient {
t.Fatal("detached local forward should not detach the main ssh client")
}
if got := forwardClient.snapshotSSHClient(); got != nil {
t.Fatal("detached local forward should close its detached ssh client")
}
}
func echoForwardPipe(conn net.Conn) {
defer conn.Close()
buf := make([]byte, 4096)
n, err := conn.Read(buf)
if err != nil {
return
}
_, _ = conn.Write(buf[:n])
}
func exerciseForwarder(t *testing.T, addr string, payload []byte) []byte {
t.Helper()
conn, err := net.DialTimeout("tcp", addr, time.Second)
if err != nil {
t.Fatalf("dial forward listener: %v", err)
}
defer conn.Close()
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
if _, err := conn.Write(payload); err != nil {
t.Fatalf("write forwarded payload: %v", err)
}
reply := make([]byte, len(payload))
if _, err := io.ReadFull(conn, reply); err != nil {
t.Fatalf("read forwarded reply: %v", err)
}
return reply
}

15
go.mod
View File

@ -1,8 +1,17 @@
module b612.me/starssh module b612.me/starssh
go 1.16 go 1.20
require ( require (
github.com/pkg/sftp v1.13.4 github.com/Microsoft/go-winio v0.6.1
golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000 github.com/pkg/sftp v1.13.9
golang.org/x/crypto v0.33.0
golang.org/x/sys v0.30.0
)
require (
github.com/kr/fs v0.1.0 // indirect
golang.org/x/mod v0.17.0 // indirect
golang.org/x/sync v0.10.0 // indirect
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
) )

93
go.sum
View File

@ -1,29 +1,92 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow=
github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
github.com/pkg/sftp v1.13.4 h1:Lb0RYJCmgUcBgZosfoi9Y9sbl6+LJgOIgk/2Y4YjMFg= github.com/pkg/sftp v1.13.9 h1:4NGkvGudBL7GteO3m6qnaQ4pC0Kvf0onSVc9gR3EWBw=
github.com/pkg/sftp v1.13.4/go.mod h1:LzqnAvaD5TWeNBsZpfKxSYn1MbjWwOsCIAFFJbpIsK8= github.com/pkg/sftp v1.13.9/go.mod h1:OBN7bVXdstkFFN/gdnHPUb5TE8eb8G1Rp9wCItqjkkA=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000 h1:SL+8VVnkqyshUSz5iNnXtrBQzvFF2SkROm6t5RczFAE= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

290
hostkey.go Normal file
View File

@ -0,0 +1,290 @@
package starssh
import (
"errors"
"fmt"
"net"
"os"
"path/filepath"
"strings"
"sync"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/knownhosts"
)
var ErrKnownHostsFileRequired = errors.New("known_hosts file is required")
var ErrHostFingerprintRequired = errors.New("host key fingerprint is required")
type HostKeyFingerprintMismatchError struct {
Expected []string
ActualSHA256 string
ActualLegacyMD5 string
}
type AcceptNewHostKeyOptions struct {
KnownHostsFile string
HashHosts bool
IncludeRemoteAddress bool
FileMode os.FileMode
}
func (e *HostKeyFingerprintMismatchError) Error() string {
if e == nil {
return ""
}
return fmt.Sprintf("host key fingerprint mismatch: want one of %s, got %s (%s)", strings.Join(e.Expected, ", "), e.ActualSHA256, e.ActualLegacyMD5)
}
func KnownHostsHostKeyCallback(files ...string) (func(string, net.Addr, ssh.PublicKey) error, error) {
trimmed := make([]string, 0, len(files))
for _, file := range files {
file = strings.TrimSpace(file)
if file == "" {
continue
}
trimmed = append(trimmed, file)
}
if len(trimmed) == 0 {
return nil, ErrKnownHostsFileRequired
}
return knownhosts.New(trimmed...)
}
func AcceptNewHostKeyCallback(file string) (func(string, net.Addr, ssh.PublicKey) error, error) {
return AcceptNewHostKeyCallbackWithOptions(AcceptNewHostKeyOptions{
KnownHostsFile: file,
IncludeRemoteAddress: true,
})
}
func AcceptNewHostKeyCallbackWithOptions(options AcceptNewHostKeyOptions) (func(string, net.Addr, ssh.PublicKey) error, error) {
options = normalizeAcceptNewHostKeyOptions(options)
if options.KnownHostsFile == "" {
return nil, ErrKnownHostsFileRequired
}
state := &acceptNewHostKeyState{
file: options.KnownHostsFile,
hashHosts: options.HashHosts,
includeRemoteAddress: options.IncludeRemoteAddress,
fileMode: options.FileMode,
}
if err := state.reload(); err != nil {
return nil, err
}
return state.checkHostKey, nil
}
func FingerprintHostKeyCallback(fingerprints ...string) (func(string, net.Addr, ssh.PublicKey) error, error) {
normalized := make([]string, 0, len(fingerprints))
seen := make(map[string]struct{}, len(fingerprints))
for _, raw := range fingerprints {
fingerprint, err := normalizeHostKeyFingerprint(raw)
if err != nil {
return nil, err
}
if fingerprint == "" {
continue
}
if _, exists := seen[fingerprint]; exists {
continue
}
seen[fingerprint] = struct{}{}
normalized = append(normalized, fingerprint)
}
if len(normalized) == 0 {
return nil, ErrHostFingerprintRequired
}
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
actualSHA256 := ssh.FingerprintSHA256(key)
actualMD5 := normalizeMD5Fingerprint(ssh.FingerprintLegacyMD5(key))
for _, want := range normalized {
if want == actualSHA256 || want == actualMD5 {
return nil
}
}
return &HostKeyFingerprintMismatchError{
Expected: append([]string(nil), normalized...),
ActualSHA256: actualSHA256,
ActualLegacyMD5: actualMD5,
}
}, nil
}
type acceptNewHostKeyState struct {
file string
hashHosts bool
includeRemoteAddress bool
fileMode os.FileMode
mu sync.Mutex
cb ssh.HostKeyCallback
}
func (s *acceptNewHostKeyState) checkHostKey(hostname string, remote net.Addr, key ssh.PublicKey) error {
s.mu.Lock()
defer s.mu.Unlock()
if s.cb != nil {
err := s.cb(hostname, remote, key)
if err == nil {
return nil
}
var keyErr *knownhosts.KeyError
if !errors.As(err, &keyErr) || len(keyErr.Want) != 0 {
return err
}
}
line, err := buildAcceptNewKnownHostsLine(hostname, remote, key, s.hashHosts, s.includeRemoteAddress)
if err != nil {
return err
}
if err := appendKnownHostsLine(s.file, line, s.fileMode); err != nil {
return err
}
if err := s.reload(); err != nil {
return err
}
if s.cb == nil {
return errors.New("known_hosts callback is nil after reload")
}
return s.cb(hostname, remote, key)
}
func (s *acceptNewHostKeyState) reload() error {
callback, err := loadKnownHostsCallback(s.file)
if err != nil {
return err
}
s.cb = callback
return nil
}
func loadKnownHostsCallback(file string) (ssh.HostKeyCallback, error) {
_, err := os.Stat(file)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil, nil
}
return nil, err
}
return knownhosts.New(file)
}
func normalizeAcceptNewHostKeyOptions(options AcceptNewHostKeyOptions) AcceptNewHostKeyOptions {
options.KnownHostsFile = strings.TrimSpace(options.KnownHostsFile)
if options.FileMode == 0 {
options.FileMode = 0o600
}
return options
}
func buildAcceptNewKnownHostsLine(hostname string, remote net.Addr, key ssh.PublicKey, hashHosts bool, includeRemoteAddress bool) (string, error) {
addresses := collectKnownHostsAddresses(hostname, remote, includeRemoteAddress)
if len(addresses) == 0 {
return "", errors.New("no hostname or remote address available for known_hosts entry")
}
patterns := make([]string, 0, len(addresses))
for _, address := range addresses {
normalized := knownhosts.Normalize(address)
if hashHosts {
patterns = append(patterns, knownhosts.HashHostname(normalized))
continue
}
patterns = append(patterns, normalized)
}
authorizedKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(key)))
return strings.Join(patterns, ",") + " " + authorizedKey, nil
}
func collectKnownHostsAddresses(hostname string, remote net.Addr, includeRemoteAddress bool) []string {
addresses := make([]string, 0, 2)
seen := make(map[string]struct{}, 2)
add := func(address string) {
address = strings.TrimSpace(address)
if address == "" {
return
}
normalized := knownhosts.Normalize(address)
if _, exists := seen[normalized]; exists {
return
}
seen[normalized] = struct{}{}
addresses = append(addresses, address)
}
add(hostname)
if includeRemoteAddress && remote != nil {
add(remote.String())
}
return addresses
}
func appendKnownHostsLine(file string, line string, mode os.FileMode) error {
if strings.TrimSpace(file) == "" {
return ErrKnownHostsFileRequired
}
if strings.TrimSpace(line) == "" {
return errors.New("known_hosts line is empty")
}
dir := filepath.Dir(file)
if dir != "." && dir != "" {
if err := os.MkdirAll(dir, 0o700); err != nil {
return err
}
}
handle, err := os.OpenFile(file, os.O_CREATE|os.O_APPEND|os.O_WRONLY, mode)
if err != nil {
return err
}
defer handle.Close()
if _, err := handle.WriteString(line + "\n"); err != nil {
return err
}
return handle.Chmod(mode)
}
func normalizeHostKeyFingerprint(raw string) (string, error) {
value := strings.TrimSpace(raw)
if value == "" {
return "", nil
}
if strings.HasPrefix(strings.ToUpper(value), "SHA256:") {
suffix := strings.TrimSpace(value[len("SHA256:"):])
if suffix == "" {
return "", ErrHostFingerprintRequired
}
return "SHA256:" + suffix, nil
}
if strings.HasPrefix(strings.ToUpper(value), "MD5:") {
suffix := strings.TrimSpace(value[len("MD5:"):])
if suffix == "" {
return "", ErrHostFingerprintRequired
}
return normalizeMD5Fingerprint("MD5:" + suffix), nil
}
if strings.Count(value, ":") >= 2 {
return normalizeMD5Fingerprint("MD5:" + value), nil
}
return "SHA256:" + value, nil
}
func normalizeMD5Fingerprint(value string) string {
if !strings.HasPrefix(strings.ToUpper(value), "MD5:") {
return "MD5:" + strings.ToLower(value)
}
return "MD5:" + strings.ToLower(value[len("MD5:"):])
}

135
keepalive.go Normal file
View File

@ -0,0 +1,135 @@
package starssh
import (
"context"
"sync"
"time"
)
var sendKeepAliveRequest = func(ctx context.Context, client sshClientRequester) error {
_, _, err := client.SendRequest("keepalive@openssh.com", true, nil)
return err
}
func (s *StarSSH) Ping() error {
return s.PingContext(context.Background())
}
func (s *StarSSH) PingContext(ctx context.Context) error {
return s.pingContext(ctx, nil)
}
func (s *StarSSH) PingContextCloseOnCancel(ctx context.Context) error {
return s.pingContext(ctx, s.Close)
}
func (s *StarSSH) pingContext(ctx context.Context, onCancel func() error) error {
if ctx == nil {
ctx = context.Background()
}
client, err := s.requireSSHClient()
if err != nil {
return err
}
runCancel := func() {}
if onCancel != nil {
var cancelOnce sync.Once
runCancel = func() {
cancelOnce.Do(func() {
_ = onCancel()
})
}
cancelDone := make(chan struct{})
defer close(cancelDone)
go func() {
select {
case <-ctx.Done():
runCancel()
case <-cancelDone:
}
}()
}
requestFunc := sendKeepAliveRequest
pingErr := make(chan error, 1)
go func() {
err := requestFunc(ctx, client)
select {
case pingErr <- err:
default:
}
}()
select {
case err := <-pingErr:
return err
case <-ctx.Done():
runCancel()
return ctx.Err()
}
}
func (s *StarSSH) startAutoKeepAlive() {
if s == nil || s.snapshotSSHClient() == nil {
return
}
interval := s.LoginInfo.KeepAliveInterval
if interval <= 0 {
return
}
timeout := s.LoginInfo.KeepAliveTimeout
if timeout <= 0 {
timeout = defaultKeepAliveTimeout
}
stop := make(chan struct{})
done := make(chan struct{})
s.keepaliveMu.Lock()
if s.keepaliveStop != nil {
s.keepaliveMu.Unlock()
return
}
s.keepaliveStop = stop
s.keepaliveDone = done
s.keepaliveMu.Unlock()
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
defer close(done)
for {
select {
case <-stop:
return
case <-ticker.C:
pingCtx, cancel := context.WithTimeout(context.Background(), timeout)
err := s.pingContext(pingCtx, nil)
cancel()
if err != nil {
s.closeFromKeepAlive()
return
}
}
}
}()
}
func (s *StarSSH) stopAutoKeepAlive() {
stop, done := s.takeKeepaliveHandles()
if stop != nil {
close(stop)
}
if done != nil {
<-done
}
}
func (s *StarSSH) closeFromKeepAlive() {
_ = s.closeTransport(false)
}

53
keepalive_test.go Normal file
View File

@ -0,0 +1,53 @@
package starssh
import (
"context"
"testing"
"time"
"golang.org/x/crypto/ssh"
)
func TestAutoKeepAliveTimeoutDoesNotDeadlock(t *testing.T) {
oldSendKeepAliveRequest := sendKeepAliveRequest
oldCloseSSHClient := closeSSHClient
t.Cleanup(func() {
sendKeepAliveRequest = oldSendKeepAliveRequest
closeSSHClient = oldCloseSSHClient
})
sendKeepAliveRequest = func(ctx context.Context, client sshClientRequester) error {
<-ctx.Done()
return ctx.Err()
}
closeSSHClient = func(client sshClientRequester) error {
return nil
}
client := &ssh.Client{}
star := &StarSSH{
LoginInfo: LoginInput{
KeepAliveInterval: 10 * time.Millisecond,
KeepAliveTimeout: 20 * time.Millisecond,
},
}
star.setTransport(client, nil)
star.startAutoKeepAlive()
star.keepaliveMu.Lock()
done := star.keepaliveDone
star.keepaliveMu.Unlock()
if done == nil {
t.Fatal("keepalive did not start")
}
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("keepalive goroutine did not exit after keepalive timeout")
}
if client := star.snapshotSSHClient(); client != nil {
t.Fatal("ssh client should be closed after keepalive timeout")
}
}

362
login.go Normal file
View File

@ -0,0 +1,362 @@
package starssh
import (
"context"
"encoding/base64"
"errors"
"fmt"
"net"
"os"
"strings"
"time"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)
var ErrHostKeyCallbackRequired = errors.New("host key callback is required; use DefaultAllowHostKeyCallback to explicitly allow any host key")
var errSSHAgentUnavailable = errors.New("ssh-agent unavailable")
var defaultAuthOrder = []AuthMethodKind{
AuthMethodSSHAgent,
AuthMethodPrivateKey,
AuthMethodPassword,
AuthMethodKeyboardInteractive,
}
func DefaultAllowHostKeyCallback(hostname string, remote net.Addr, key ssh.PublicKey) error {
return nil
}
func LoginContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
return loginWithContext(ctx, info)
}
func Login(info LoginInput) (*StarSSH, error) {
return LoginContext(context.Background(), info)
}
func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
info = normalizeLoginInput(info)
if info.HostKeyCallback == nil {
return nil, ErrHostKeyCallbackRequired
}
loginCtx, cancel := contextWithLoginTimeout(ctx, info.Timeout)
defer cancel()
sshInfo := &StarSSH{
LoginInfo: info,
}
auth, authCleanup, err := buildAuthMethods(info)
if err != nil {
return nil, err
}
if authCleanup != nil {
defer authCleanup()
}
hostKeyCallback := func(hostname string, remote net.Addr, key ssh.PublicKey) error {
sshInfo.PublicKey = key
sshInfo.RemoteAddr = remote
sshInfo.Hostname = hostname
return info.HostKeyCallback(hostname, remote, key)
}
bannerCallback := func(banner string) error {
sshInfo.Banner = banner
if info.BannerCallback != nil {
return info.BannerCallback(banner)
}
return nil
}
clientConfig := &ssh.ClientConfig{
User: info.User,
Auth: auth,
Timeout: info.Timeout,
HostKeyCallback: hostKeyCallback,
BannerCallback: bannerCallback,
}
if len(info.Ciphers) > 0 || len(info.MACs) > 0 || len(info.KeyExchanges) > 0 {
clientConfig.Config = ssh.Config{
Ciphers: info.Ciphers,
MACs: info.MACs,
KeyExchanges: info.KeyExchanges,
}
}
targetAddr := joinHostPort(info.Addr, info.Port)
rawConn, upstream, err := dialTargetConn(loginCtx, info)
if err != nil {
return sshInfo, err
}
restoreDeadline := applyConnDeadline(rawConn, loginCtx, info.Timeout)
defer restoreDeadline()
clientConn, chans, reqs, err := ssh.NewClientConn(rawConn, targetAddr, clientConfig)
if err != nil {
_ = rawConn.Close()
if upstream != nil {
_ = upstream.Close()
}
return sshInfo, err
}
client := ssh.NewClient(clientConn, chans, reqs)
sshInfo.setTransport(client, upstream)
if sshInfo.PublicKey != nil {
sshInfo.PubkeyBase64 = base64.StdEncoding.EncodeToString(sshInfo.PublicKey.Marshal())
}
sshInfo.startAutoKeepAlive()
return sshInfo, nil
}
func contextWithLoginTimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
if ctx == nil {
ctx = context.Background()
}
if timeout <= 0 {
return ctx, func() {}
}
return context.WithTimeout(ctx, timeout)
}
func LoginSimple(host string, user string, passwd string, prikeyPath string, port int, timeout time.Duration) (*StarSSH, error) {
info := LoginInput{
Addr: host,
Port: port,
Timeout: timeout,
User: user,
HostKeyCallback: DefaultAllowHostKeyCallback,
}
if prikeyPath != "" {
prikey, err := os.ReadFile(prikeyPath)
if err != nil {
return nil, err
}
info.Prikey = string(prikey)
if passwd != "" {
info.PrikeyPwd = passwd
}
} else {
info.Password = passwd
}
return Login(info)
}
func normalizeLoginInput(info LoginInput) LoginInput {
if info.Port <= 0 {
info.Port = defaultSSHPort
}
if info.Timeout <= 0 {
info.Timeout = defaultLoginTimeout
}
return info
}
func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) {
order, err := normalizeAuthOrder(info.AuthOrder)
if err != nil {
return nil, nil, err
}
auth := make([]ssh.AuthMethod, 0, len(order))
var agentErr error
var cleanupFuncs []func()
for _, methodKind := range order {
switch methodKind {
case AuthMethodPrivateKey:
method, err := buildPrivateKeyAuthMethod(info)
if err != nil {
return nil, nil, err
}
if method != nil {
auth = append(auth, method)
}
case AuthMethodPassword:
method := buildPasswordAuthMethod(info.Password, info.PasswordCallback)
if method != nil {
auth = append(auth, method)
}
case AuthMethodKeyboardInteractive:
method := buildKeyboardInteractiveAuthMethod(info.Password, info.PasswordCallback, info.KeyboardInteractiveCallback)
if method != nil {
auth = append(auth, method)
}
case AuthMethodSSHAgent:
if info.DisableSSHAgent {
continue
}
agentMethod, cleanup, err := buildSSHAgentAuthMethod(info.Timeout)
if err != nil {
agentErr = err
continue
}
if agentMethod != nil {
auth = append(auth, agentMethod)
}
if cleanup != nil {
cleanupFuncs = append(cleanupFuncs, cleanup)
}
}
}
if len(auth) == 0 {
if agentErr != nil {
return nil, nil, fmt.Errorf("no authentication method provided; ssh-agent unavailable: %w", agentErr)
}
return nil, nil, errors.New("no authentication method provided: password, private key, or ssh-agent is required")
}
return auth, composeCleanup(cleanupFuncs...), nil
}
func normalizeAuthOrder(order []AuthMethodKind) ([]AuthMethodKind, error) {
if len(order) == 0 {
return append([]AuthMethodKind(nil), defaultAuthOrder...), nil
}
normalized := make([]AuthMethodKind, 0, len(order))
seen := make(map[AuthMethodKind]struct{}, len(order))
for _, raw := range order {
kind := AuthMethodKind(strings.ToLower(strings.TrimSpace(string(raw))))
if kind == "" {
return nil, errors.New("auth order contains an empty auth method")
}
if !isSupportedAuthMethodKind(kind) {
return nil, fmt.Errorf("unsupported auth method %q", raw)
}
if _, exists := seen[kind]; exists {
continue
}
seen[kind] = struct{}{}
normalized = append(normalized, kind)
}
if len(normalized) == 0 {
return nil, errors.New("auth order is empty")
}
return normalized, nil
}
func isSupportedAuthMethodKind(kind AuthMethodKind) bool {
switch kind {
case AuthMethodPrivateKey, AuthMethodPassword, AuthMethodKeyboardInteractive, AuthMethodSSHAgent:
return true
default:
return false
}
}
func buildPrivateKeyAuthMethod(info LoginInput) (ssh.AuthMethod, error) {
if strings.TrimSpace(info.Prikey) == "" {
return nil, nil
}
pemBytes := []byte(info.Prikey)
if info.PrikeyPwd == "" {
signer, err := ssh.ParsePrivateKey(pemBytes)
if err != nil {
return nil, err
}
return ssh.PublicKeys(signer), nil
}
signer, err := ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(info.PrikeyPwd))
if err != nil {
return nil, err
}
return ssh.PublicKeys(signer), nil
}
func buildPasswordAuthMethod(password string, callback func() (string, error)) ssh.AuthMethod {
if password != "" {
return ssh.Password(password)
}
if callback != nil {
return ssh.PasswordCallback(callback)
}
return nil
}
func buildKeyboardInteractiveAuthMethod(
password string,
passwordCallback func() (string, error),
challenge ssh.KeyboardInteractiveChallenge,
) ssh.AuthMethod {
if challenge != nil {
return ssh.KeyboardInteractive(challenge)
}
if password == "" && passwordCallback == nil {
return nil
}
keyboardInteractiveChallenge := func(user, instruction string, questions []string, echos []bool) ([]string, error) {
if len(questions) == 0 {
return []string{}, nil
}
answer := password
if answer == "" {
var err error
answer, err = passwordCallback()
if err != nil {
return nil, err
}
}
answers := make([]string, len(questions))
for i := range questions {
answers[i] = answer
}
return answers, nil
}
return ssh.KeyboardInteractive(keyboardInteractiveChallenge)
}
func buildSSHAgentAuthMethod(timeout time.Duration) (ssh.AuthMethod, func(), error) {
conn, err := dialSSHAgent(timeout)
if err != nil {
if errors.Is(err, errSSHAgentUnavailable) {
return nil, nil, nil
}
return nil, nil, err
}
if conn == nil {
return nil, nil, nil
}
signers, err := agent.NewClient(conn).Signers()
if err != nil {
_ = conn.Close()
return nil, nil, err
}
if len(signers) == 0 {
_ = conn.Close()
return nil, nil, errors.New("ssh-agent has no loaded keys")
}
return ssh.PublicKeys(signers...), func() {
_ = conn.Close()
}, nil
}
func composeCleanup(funcs ...func()) func() {
if len(funcs) == 0 {
return nil
}
return func() {
for i := len(funcs) - 1; i >= 0; i-- {
if funcs[i] != nil {
funcs[i]()
}
}
}
}

466
pool.go Normal file
View File

@ -0,0 +1,466 @@
package starssh
import (
"context"
"errors"
"sync"
"time"
)
const defaultExecPoolMaxOpenConns = 4
var ErrExecPoolClosed = errors.New("exec pool is closed")
type ExecPoolConfig struct {
Login LoginInput
MaxOpenConns int
MaxIdleConns int
MaxIdleTime time.Duration
DisableHealthCheck bool
HealthCheckTimeout time.Duration
}
type ExecPoolStats struct {
MaxOpenConns int
MaxIdleConns int
MaxIdleTime time.Duration
OpenConns int
IdleConns int
InUseConns int
}
type ExecPool struct {
loginInfo LoginInput
maxOpen int
maxIdle int
maxIdleTime time.Duration
idle chan *pooledClient
done chan struct{}
closeOnce sync.Once
healthCheckOnAcquire bool
healthCheckTimeout time.Duration
mu sync.Mutex
open int
closed bool
}
type pooledClient struct {
client *StarSSH
idleAt time.Time
}
func NewExecPool(config ExecPoolConfig) *ExecPool {
maxOpen := config.MaxOpenConns
if maxOpen <= 0 {
maxOpen = defaultExecPoolMaxOpenConns
}
maxIdle := config.MaxIdleConns
if maxIdle <= 0 || maxIdle > maxOpen {
maxIdle = maxOpen
}
return &ExecPool{
loginInfo: config.Login,
maxOpen: maxOpen,
maxIdle: maxIdle,
maxIdleTime: normalizeMaxIdleTime(config.MaxIdleTime),
idle: make(chan *pooledClient, maxIdle),
done: make(chan struct{}),
healthCheckOnAcquire: !config.DisableHealthCheck,
healthCheckTimeout: normalizeHealthCheckTimeout(config.HealthCheckTimeout),
}
}
func (p *ExecPool) Exec(ctx context.Context, req ExecRequest) (*ExecResult, error) {
client, err := p.Acquire(ctx)
if err != nil {
return nil, err
}
result, execErr := client.Exec(ctx, req)
if execErr != nil {
p.Discard(client)
return result, execErr
}
if releaseErr := p.Release(client); releaseErr != nil {
return result, releaseErr
}
return result, nil
}
func (p *ExecPool) ExecString(ctx context.Context, command string) (*ExecResult, error) {
return p.Exec(ctx, ExecRequest{
Command: command,
})
}
func (p *ExecPool) ExecStream(ctx context.Context, req ExecRequest, onChunk func(ExecStreamChunk)) (*ExecResult, error) {
client, err := p.Acquire(ctx)
if err != nil {
return nil, err
}
result, execErr := client.ExecStream(ctx, req, onChunk)
if execErr != nil {
p.Discard(client)
return result, execErr
}
if releaseErr := p.Release(client); releaseErr != nil {
return result, releaseErr
}
return result, nil
}
func (p *ExecPool) WarmUp(ctx context.Context, targetIdle int) error {
if p == nil {
return errors.New("exec pool is nil")
}
if ctx == nil {
ctx = context.Background()
}
targetIdle = p.normalizeWarmUpTarget(targetIdle)
if targetIdle == 0 {
return nil
}
for {
if err := ctx.Err(); err != nil {
return err
}
idleCount, create, err := p.tryWarmUp(targetIdle)
if err != nil {
return err
}
if idleCount >= targetIdle || !create {
return nil
}
conn, err := LoginContext(ctx, p.loginInfo)
if err != nil {
p.releaseSlot()
return err
}
if err := p.Release(conn); err != nil {
return err
}
}
}
func (p *ExecPool) Acquire(ctx context.Context) (*StarSSH, error) {
if p == nil {
return nil, errors.New("exec pool is nil")
}
if ctx == nil {
ctx = context.Background()
}
for {
idleClient, create, err := p.tryAcquire()
if err != nil {
return nil, err
}
if idleClient != nil {
client, ok := p.takeIdleClient(ctx, idleClient)
if !ok {
continue
}
return client, nil
}
if create {
conn, err := LoginContext(ctx, p.loginInfo)
if err != nil {
p.releaseSlot()
return nil, err
}
return conn, nil
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-p.done:
return nil, ErrExecPoolClosed
case idleClient = <-p.idle:
if idleClient == nil {
continue
}
client, ok := p.takeIdleClient(ctx, idleClient)
if !ok {
continue
}
return client, nil
}
}
}
func (p *ExecPool) Release(client *StarSSH) error {
if p == nil {
return errors.New("exec pool is nil")
}
if client == nil {
p.releaseSlot()
return nil
}
p.mu.Lock()
if p.closed {
p.mu.Unlock()
p.closeClient(client)
return nil
}
select {
case p.idle <- &pooledClient{
client: client,
idleAt: time.Now(),
}:
p.mu.Unlock()
return nil
default:
p.mu.Unlock()
p.closeClient(client)
return nil
}
}
func (p *ExecPool) Discard(client *StarSSH) {
if p == nil {
return
}
if client == nil {
p.releaseSlot()
return
}
p.closeClient(client)
}
func (p *ExecPool) Stats() ExecPoolStats {
if p == nil {
return ExecPoolStats{}
}
p.mu.Lock()
defer p.mu.Unlock()
idleCount := len(p.idle)
openCount := p.open
inUseCount := openCount - idleCount
if inUseCount < 0 {
inUseCount = 0
}
return ExecPoolStats{
MaxOpenConns: p.maxOpen,
MaxIdleConns: p.maxIdle,
MaxIdleTime: p.maxIdleTime,
OpenConns: openCount,
IdleConns: idleCount,
InUseConns: inUseCount,
}
}
func (p *ExecPool) Close() error {
if p == nil {
return nil
}
var closeErr error
p.closeOnce.Do(func() {
p.mu.Lock()
p.closed = true
idleClients := p.drainIdleLocked()
p.mu.Unlock()
close(p.done)
for _, client := range idleClients {
if err := client.Close(); err != nil && closeErr == nil {
closeErr = err
}
}
})
return closeErr
}
func (p *ExecPool) CloseIdle() error {
if p == nil {
return nil
}
p.mu.Lock()
idleClients := p.drainIdleLocked()
p.mu.Unlock()
var closeErr error
for _, client := range idleClients {
if err := client.Close(); err != nil && closeErr == nil {
closeErr = err
}
}
return closeErr
}
func (p *ExecPool) tryAcquire() (*pooledClient, bool, error) {
p.mu.Lock()
defer p.mu.Unlock()
if p.closed {
return nil, false, ErrExecPoolClosed
}
select {
case client := <-p.idle:
return client, false, nil
default:
}
if p.open < p.maxOpen {
p.open++
return nil, true, nil
}
return nil, false, nil
}
func (p *ExecPool) tryWarmUp(targetIdle int) (int, bool, error) {
p.mu.Lock()
defer p.mu.Unlock()
if p.closed {
return 0, false, ErrExecPoolClosed
}
idleCount := len(p.idle)
if idleCount >= targetIdle {
return idleCount, false, nil
}
if p.open >= p.maxOpen {
return idleCount, false, nil
}
p.open++
return idleCount, true, nil
}
func (p *ExecPool) normalizeWarmUpTarget(targetIdle int) int {
if p == nil || p.maxIdle <= 0 {
return 0
}
if targetIdle <= 0 {
return p.maxIdle
}
if targetIdle > p.maxIdle {
return p.maxIdle
}
return targetIdle
}
func (p *ExecPool) takeIdleClient(ctx context.Context, idleClient *pooledClient) (*StarSSH, bool) {
if idleClient == nil {
return nil, false
}
if idleClient.client == nil {
p.releaseSlot()
return nil, false
}
if p.isIdleExpired(idleClient) {
p.closePooledClient(idleClient)
return nil, false
}
if err := p.healthCheckClient(ctx, idleClient.client); err != nil {
p.closePooledClient(idleClient)
return nil, false
}
return idleClient.client, true
}
func (p *ExecPool) isIdleExpired(client *pooledClient) bool {
if p == nil || client == nil || client.client == nil {
return false
}
if p.maxIdleTime <= 0 || client.idleAt.IsZero() {
return false
}
return time.Since(client.idleAt) >= p.maxIdleTime
}
func (p *ExecPool) drainIdleLocked() []*StarSSH {
clients := make([]*StarSSH, 0, len(p.idle))
for {
select {
case idleClient := <-p.idle:
if p.open > 0 {
p.open--
}
if idleClient == nil || idleClient.client == nil {
continue
}
clients = append(clients, idleClient.client)
default:
return clients
}
}
}
func (p *ExecPool) releaseSlot() {
p.mu.Lock()
defer p.mu.Unlock()
if p.open > 0 {
p.open--
}
}
func (p *ExecPool) closeClient(client *StarSSH) {
if client != nil {
_ = client.Close()
}
p.releaseSlot()
}
func (p *ExecPool) closePooledClient(client *pooledClient) {
if client == nil {
return
}
if client.client == nil {
p.releaseSlot()
return
}
p.closeClient(client.client)
}
func (p *ExecPool) healthCheckClient(ctx context.Context, client *StarSSH) error {
if client == nil {
return errors.New("ssh client is nil")
}
if !p.healthCheckOnAcquire {
return nil
}
if ctx == nil {
ctx = context.Background()
}
timeout := p.healthCheckTimeout
if timeout > 0 {
healthCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
ctx = healthCtx
}
return client.PingContext(ctx)
}
func normalizeHealthCheckTimeout(timeout time.Duration) time.Duration {
if timeout <= 0 {
return defaultKeepAliveTimeout
}
return timeout
}
func normalizeMaxIdleTime(timeout time.Duration) time.Duration {
if timeout <= 0 {
return 0
}
return timeout
}

121
session.go Normal file
View File

@ -0,0 +1,121 @@
package starssh
import (
"errors"
"io"
"net"
"strings"
"golang.org/x/crypto/ssh"
)
func (s *StarSSH) Close() error {
return s.closeTransport(true)
}
func (s *StarSSH) NewSession() (*ssh.Session, error) {
return s.NewPTYSession(nil)
}
func (s *StarSSH) NewExecSession() (*ssh.Session, error) {
client, err := s.requireSSHClient()
if err != nil {
return nil, err
}
return NewExecSession(client)
}
func (s *StarSSH) NewPTYSession(config *TerminalConfig) (*ssh.Session, error) {
client, err := s.requireSSHClient()
if err != nil {
return nil, err
}
return NewPTYSession(client, config)
}
func NewTransferSession(client *ssh.Client) (*ssh.Session, error) {
return NewExecSession(client)
}
func NewExecSession(client *ssh.Client) (*ssh.Session, error) {
if client == nil {
return nil, errors.New("ssh client is nil")
}
return client.NewSession()
}
func NewSession(client *ssh.Client) (*ssh.Session, error) {
return NewPTYSession(client, nil)
}
func NewPTYSession(client *ssh.Client, config *TerminalConfig) (*ssh.Session, error) {
if client == nil {
return nil, errors.New("ssh client is nil")
}
session, err := client.NewSession()
if err != nil {
return nil, err
}
cfg := normalizeTerminalConfig(config)
if err := session.RequestPty(cfg.Term, cfg.Rows, cfg.Columns, cfg.Modes); err != nil {
_ = session.Close()
return nil, err
}
return session, nil
}
func normalizeTerminalConfig(config *TerminalConfig) TerminalConfig {
cfg := TerminalConfig{
Term: defaultPTYTerm,
Rows: defaultPTYRows,
Columns: defaultPTYColumns,
Modes: ssh.TerminalModes{
ssh.ECHO: 1,
ssh.TTY_OP_ISPEED: 14400,
ssh.TTY_OP_OSPEED: 14400,
},
}
if config == nil {
return cfg
}
if strings.TrimSpace(config.Term) != "" {
cfg.Term = config.Term
}
if config.Rows > 0 {
cfg.Rows = config.Rows
}
if config.Columns > 0 {
cfg.Columns = config.Columns
}
if len(config.Modes) > 0 {
cfg.Modes = config.Modes
}
return cfg
}
func normalizeAlreadyClosedError(err error) error {
if err == nil {
return nil
}
if errors.Is(err, net.ErrClosed) {
return nil
}
if errors.Is(err, io.EOF) {
return nil
}
if strings.Contains(err.Error(), "use of closed network connection") {
return nil
}
return err
}
func closeUpstream(upstream *StarSSH) error {
if upstream == nil {
return nil
}
return upstream.Close()
}

1607
sftp.go

File diff suppressed because it is too large Load Diff

475
sftp_test.go Normal file
View File

@ -0,0 +1,475 @@
package starssh
import (
"context"
"errors"
"io"
"net"
"os"
"path/filepath"
"strings"
"testing"
"github.com/pkg/sftp"
)
func TestNormalizeSFTPTransferOptionsDefaultsAtomicDownload(t *testing.T) {
opts := normalizeSFTPTransferOptions(nil)
if !opts.AtomicUpload {
t.Fatal("expected atomic upload to default to enabled")
}
if !opts.AtomicDownload {
t.Fatal("expected atomic download to default to enabled")
}
}
func TestTransferOutContextVerifyFailurePreservesRemoteTarget(t *testing.T) {
client := newSFTPTestClient(t)
root := t.TempDir()
localPath := filepath.Join(root, "local.txt")
remotePath := filepath.Join(root, "remote.txt")
if err := os.WriteFile(localPath, []byte("new payload"), 0o644); err != nil {
t.Fatalf("write local file: %v", err)
}
if err := os.WriteFile(remotePath, []byte("original remote"), 0o644); err != nil {
t.Fatalf("write remote file: %v", err)
}
verifyErr := errors.New("verify failed")
var verifiedPath string
oldVerifyRemoteSize := sftpVerifyRemoteSizeFunc
sftpVerifyRemoteSizeFunc = func(client *sftp.Client, remotePath string, expected int64) error {
verifiedPath = remotePath
return verifyErr
}
t.Cleanup(func() {
sftpVerifyRemoteSizeFunc = oldVerifyRemoteSize
})
err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil))
if !errors.Is(err, verifyErr) {
t.Fatalf("expected verify failure, got %v", err)
}
if verifiedPath == remotePath {
t.Fatal("expected upload verification to run against temp path before final rename")
}
data, err := os.ReadFile(remotePath)
if err != nil {
t.Fatalf("read remote file: %v", err)
}
if string(data) != "original remote" {
t.Fatalf("remote target was replaced on verify failure: %q", string(data))
}
assertNoTransferTemps(t, remotePath)
}
func TestTransferOutContextRejectsRemoteSymlinkTarget(t *testing.T) {
client := newSFTPTestClient(t)
root := t.TempDir()
localPath := filepath.Join(root, "local.txt")
remoteRealPath := filepath.Join(root, "remote-real.txt")
remotePath := filepath.Join(root, "remote-link.txt")
if err := os.WriteFile(localPath, []byte("new payload"), 0o644); err != nil {
t.Fatalf("write local file: %v", err)
}
if err := os.WriteFile(remoteRealPath, []byte("original remote"), 0o644); err != nil {
t.Fatalf("write remote backing file: %v", err)
}
if err := os.Symlink(remoteRealPath, remotePath); err != nil {
t.Skipf("symlink unsupported: %v", err)
}
err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil))
if err == nil || !strings.Contains(err.Error(), "symlink") {
t.Fatalf("expected symlink rejection, got %v", err)
}
info, err := os.Lstat(remotePath)
if err != nil {
t.Fatalf("lstat remote symlink: %v", err)
}
if info.Mode()&os.ModeSymlink == 0 {
t.Fatal("expected remote target to remain a symlink")
}
data, err := os.ReadFile(remoteRealPath)
if err != nil {
t.Fatalf("read remote backing file: %v", err)
}
if string(data) != "original remote" {
t.Fatalf("remote backing file changed unexpectedly: %q", string(data))
}
assertNoTransferTemps(t, remotePath)
}
func TestTransferOutContextRejectsRemoteDirectoryTarget(t *testing.T) {
client := newSFTPTestClient(t)
root := t.TempDir()
localPath := filepath.Join(root, "local.txt")
remotePath := filepath.Join(root, "remote-dir")
if err := os.WriteFile(localPath, []byte("new payload"), 0o644); err != nil {
t.Fatalf("write local file: %v", err)
}
if err := os.Mkdir(remotePath, 0o755); err != nil {
t.Fatalf("mkdir remote target: %v", err)
}
err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil))
if err == nil || !strings.Contains(err.Error(), "directory") {
t.Fatalf("expected directory rejection, got %v", err)
}
info, err := os.Stat(remotePath)
if err != nil {
t.Fatalf("stat remote directory: %v", err)
}
if !info.IsDir() {
t.Fatal("expected remote target to remain a directory")
}
assertNoTransferTemps(t, remotePath)
}
func TestTransferOutContextPreservesRemoteModeOnOverwrite(t *testing.T) {
client := newSFTPTestClient(t)
root := t.TempDir()
localPath := filepath.Join(root, "local.txt")
remotePath := filepath.Join(root, "remote.txt")
if err := os.WriteFile(localPath, []byte("new payload"), 0o644); err != nil {
t.Fatalf("write local file: %v", err)
}
if err := os.WriteFile(remotePath, []byte("original remote"), 0o755); err != nil {
t.Fatalf("write remote file: %v", err)
}
if err := os.Chmod(remotePath, 0o755); err != nil {
t.Fatalf("chmod remote file: %v", err)
}
if err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil)); err != nil {
t.Fatalf("transfer out: %v", err)
}
assertMode(t, remotePath, 0o755)
assertFileContent(t, remotePath, "new payload")
}
func TestTransferOutContextAppliesLocalModeForNewRemoteFile(t *testing.T) {
client := newSFTPTestClient(t)
root := t.TempDir()
localPath := filepath.Join(root, "local.txt")
remotePath := filepath.Join(root, "remote.txt")
if err := os.WriteFile(localPath, []byte("new payload"), 0o751); err != nil {
t.Fatalf("write local file: %v", err)
}
if err := os.Chmod(localPath, 0o751); err != nil {
t.Fatalf("chmod local file: %v", err)
}
if err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil)); err != nil {
t.Fatalf("transfer out: %v", err)
}
assertMode(t, remotePath, 0o751)
assertFileContent(t, remotePath, "new payload")
}
func TestTransferOutByteContextPreservesRemoteModeOnOverwrite(t *testing.T) {
client := newSFTPTestClient(t)
root := t.TempDir()
remotePath := filepath.Join(root, "remote.txt")
if err := os.WriteFile(remotePath, []byte("original remote"), 0o755); err != nil {
t.Fatalf("write remote file: %v", err)
}
if err := os.Chmod(remotePath, 0o755); err != nil {
t.Fatalf("chmod remote file: %v", err)
}
if err := transferOutByteContext(context.Background(), client, []byte("byte payload"), remotePath, normalizeSFTPTransferOptions(nil)); err != nil {
t.Fatalf("transfer out bytes: %v", err)
}
assertMode(t, remotePath, 0o755)
assertFileContent(t, remotePath, "byte payload")
}
func TestTransferInContextVerifyFailurePreservesLocalTarget(t *testing.T) {
client := newSFTPTestClient(t)
root := t.TempDir()
srcPath := filepath.Join(root, "remote.txt")
dstPath := filepath.Join(root, "local.txt")
if err := os.WriteFile(srcPath, []byte("fresh remote payload"), 0o644); err != nil {
t.Fatalf("write remote file: %v", err)
}
if err := os.WriteFile(dstPath, []byte("original local"), 0o644); err != nil {
t.Fatalf("write local file: %v", err)
}
verifyErr := errors.New("verify local failed")
var verifiedPath string
oldVerifyLocalSize := sftpVerifyLocalSizeFunc
sftpVerifyLocalSizeFunc = func(localPath string, expected int64) error {
verifiedPath = localPath
return verifyErr
}
t.Cleanup(func() {
sftpVerifyLocalSizeFunc = oldVerifyLocalSize
})
err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil))
if !errors.Is(err, verifyErr) {
t.Fatalf("expected verify failure, got %v", err)
}
if verifiedPath == dstPath {
t.Fatal("expected download verification to run against temp path before final rename")
}
data, err := os.ReadFile(dstPath)
if err != nil {
t.Fatalf("read local file: %v", err)
}
if string(data) != "original local" {
t.Fatalf("local target was replaced on verify failure: %q", string(data))
}
assertNoTransferTemps(t, dstPath)
}
func TestTransferInContextRejectsLocalSymlinkTarget(t *testing.T) {
client := newSFTPTestClient(t)
root := t.TempDir()
srcPath := filepath.Join(root, "remote.txt")
localRealPath := filepath.Join(root, "local-real.txt")
dstPath := filepath.Join(root, "local-link.txt")
if err := os.WriteFile(srcPath, []byte("fresh remote payload"), 0o644); err != nil {
t.Fatalf("write remote file: %v", err)
}
if err := os.WriteFile(localRealPath, []byte("original local"), 0o644); err != nil {
t.Fatalf("write local backing file: %v", err)
}
if err := os.Symlink(localRealPath, dstPath); err != nil {
t.Skipf("symlink unsupported: %v", err)
}
err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil))
if err == nil || !strings.Contains(err.Error(), "symlink") {
t.Fatalf("expected symlink rejection, got %v", err)
}
info, err := os.Lstat(dstPath)
if err != nil {
t.Fatalf("lstat local symlink: %v", err)
}
if info.Mode()&os.ModeSymlink == 0 {
t.Fatal("expected local target to remain a symlink")
}
assertFileContent(t, localRealPath, "original local")
assertNoTransferTemps(t, dstPath)
}
func TestTransferInContextRejectsLocalDirectoryTarget(t *testing.T) {
client := newSFTPTestClient(t)
root := t.TempDir()
srcPath := filepath.Join(root, "remote.txt")
dstPath := filepath.Join(root, "local-dir")
if err := os.WriteFile(srcPath, []byte("fresh remote payload"), 0o644); err != nil {
t.Fatalf("write remote file: %v", err)
}
if err := os.Mkdir(dstPath, 0o755); err != nil {
t.Fatalf("mkdir local target: %v", err)
}
err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil))
if err == nil || !strings.Contains(err.Error(), "directory") {
t.Fatalf("expected directory rejection, got %v", err)
}
info, err := os.Stat(dstPath)
if err != nil {
t.Fatalf("stat local directory: %v", err)
}
if !info.IsDir() {
t.Fatal("expected local target to remain a directory")
}
assertNoTransferTemps(t, dstPath)
}
func TestTransferInContextPreservesLocalModeOnOverwrite(t *testing.T) {
client := newSFTPTestClient(t)
root := t.TempDir()
srcPath := filepath.Join(root, "remote.txt")
dstPath := filepath.Join(root, "local.sh")
if err := os.WriteFile(srcPath, []byte("#!/bin/sh\necho remote\n"), 0o644); err != nil {
t.Fatalf("write remote file: %v", err)
}
if err := os.WriteFile(dstPath, []byte("#!/bin/sh\necho local\n"), 0o755); err != nil {
t.Fatalf("write local file: %v", err)
}
if err := os.Chmod(dstPath, 0o755); err != nil {
t.Fatalf("chmod local file: %v", err)
}
if err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)); err != nil {
t.Fatalf("transfer in: %v", err)
}
assertMode(t, dstPath, 0o755)
assertFileContent(t, dstPath, "#!/bin/sh\necho remote\n")
}
func TestTransferInContextAppliesRemoteModeForNewLocalFile(t *testing.T) {
client := newSFTPTestClient(t)
root := t.TempDir()
srcPath := filepath.Join(root, "remote.sh")
dstPath := filepath.Join(root, "local.sh")
if err := os.WriteFile(srcPath, []byte("#!/bin/sh\necho remote\n"), 0o751); err != nil {
t.Fatalf("write remote file: %v", err)
}
if err := os.Chmod(srcPath, 0o751); err != nil {
t.Fatalf("chmod remote file: %v", err)
}
if err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)); err != nil {
t.Fatalf("transfer in: %v", err)
}
assertMode(t, dstPath, 0o751)
assertFileContent(t, dstPath, "#!/bin/sh\necho remote\n")
}
func TestTransferInContextCopyFailurePreservesLocalTarget(t *testing.T) {
client := newSFTPTestClient(t)
root := t.TempDir()
srcPath := filepath.Join(root, "remote.txt")
dstPath := filepath.Join(root, "local.txt")
if err := os.WriteFile(srcPath, []byte("fresh remote payload"), 0o644); err != nil {
t.Fatalf("write remote file: %v", err)
}
if err := os.WriteFile(dstPath, []byte("original local"), 0o644); err != nil {
t.Fatalf("write local file: %v", err)
}
copyErr := errors.New("copy failed")
var copyTargetPath string
oldCopy := sftpCopyWithProgressFunc
sftpCopyWithProgressFunc = func(ctx context.Context, dst io.Writer, src io.Reader, bufSize int, total int64, progress func(float64)) (int64, error) {
file, ok := dst.(*os.File)
if !ok {
t.Fatalf("expected local temp file writer, got %T", dst)
}
copyTargetPath = file.Name()
buf := make([]byte, 8)
n, readErr := src.Read(buf)
if readErr != nil && !errors.Is(readErr, io.EOF) {
return 0, readErr
}
if n > 0 {
written, err := dst.Write(buf[:n])
if err != nil {
return int64(written), err
}
return int64(written), copyErr
}
return 0, copyErr
}
t.Cleanup(func() {
sftpCopyWithProgressFunc = oldCopy
})
err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil))
if !errors.Is(err, copyErr) {
t.Fatalf("expected copy failure, got %v", err)
}
if copyTargetPath == dstPath {
t.Fatal("expected partial download writes to stay on temp path")
}
data, err := os.ReadFile(dstPath)
if err != nil {
t.Fatalf("read local file: %v", err)
}
if string(data) != "original local" {
t.Fatalf("local target was modified by partial download: %q", string(data))
}
assertNoTransferTemps(t, dstPath)
}
func newSFTPTestClient(t *testing.T) *sftp.Client {
t.Helper()
serverConn, clientConn := net.Pipe()
server, err := sftp.NewServer(serverConn)
if err != nil {
t.Fatalf("create sftp server: %v", err)
}
serveErrCh := make(chan error, 1)
go func() {
serveErrCh <- server.Serve()
}()
client, err := sftp.NewClientPipe(clientConn, clientConn)
if err != nil {
_ = server.Close()
t.Fatalf("create sftp client: %v", err)
}
t.Cleanup(func() {
_ = client.Close()
_ = server.Close()
serveErr := <-serveErrCh
if serveErr == nil || errors.Is(serveErr, io.EOF) || normalizeAlreadyClosedError(serveErr) == nil {
return
}
t.Errorf("unexpected sftp server error: %v", serveErr)
})
return client
}
func assertNoTransferTemps(t *testing.T, targetPath string) {
t.Helper()
matches, err := filepath.Glob(targetPath + defaultSFTPTempSuffix + "*")
if err != nil {
t.Fatalf("glob temp files: %v", err)
}
if len(matches) != 0 {
t.Fatalf("expected temp artifacts to be cleaned up, got %v", matches)
}
}
func assertMode(t *testing.T, targetPath string, want os.FileMode) {
t.Helper()
info, err := os.Stat(targetPath)
if err != nil {
t.Fatalf("stat %q: %v", targetPath, err)
}
if got := info.Mode().Perm(); got != want {
t.Fatalf("unexpected mode for %q: got %o want %o", targetPath, got, want)
}
}
func assertFileContent(t *testing.T, targetPath string, want string) {
t.Helper()
data, err := os.ReadFile(targetPath)
if err != nil {
t.Fatalf("read %q: %v", targetPath, err)
}
if string(data) != want {
t.Fatalf("unexpected content for %q: got %q want %q", targetPath, string(data), want)
}
}

592
shell.go Normal file
View File

@ -0,0 +1,592 @@
package starssh
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"strings"
"time"
)
var errStarShellPOSIXOnly = errors.New("legacy StarShell only supports POSIX-compatible shells")
type ShellRequest struct {
Command string
Timeout time.Duration
Keyword string
UseWaitDefault *bool
}
// NewShell creates the legacy prompt-driven POSIX shell helper.
// For raw interactive terminal flows, prefer NewTerminal.
func (s *StarSSH) NewShell() (shell *StarShell, err error) {
shell = &StarShell{
UseWaitDefault: true,
WaitTimeout: defaultShellWaitTimeout,
isecho: true,
iscolor: true,
promptToken: defaultShellPromptToken,
}
shell.Session, err = s.NewPTYSession(nil)
if err != nil {
return nil, err
}
shell.in, err = shell.Session.StdinPipe()
if err != nil {
_ = shell.Session.Close()
return nil, err
}
stdout, err := shell.Session.StdoutPipe()
if err != nil {
_ = shell.Session.Close()
return nil, err
}
shell.out = bufio.NewReader(stdout)
stderr, err := shell.Session.StderrPipe()
if err != nil {
_ = shell.Session.Close()
return nil, err
}
shell.er = bufio.NewReader(stderr)
if err := shell.Session.Shell(); err != nil {
_ = shell.Session.Close()
return nil, err
}
go shell.watchSession()
shell.gohub()
if err := shell.configurePrompt(context.Background()); err != nil {
_ = shell.Session.Close()
return nil, err
}
shell.Clear()
return shell, nil
}
func (s *StarShell) configurePrompt(ctx context.Context) error {
if s == nil {
return errors.New("shell is nil")
}
if ctx == nil {
ctx = context.Background()
}
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
timeoutCtx, cancel := context.WithTimeout(ctx, defaultShellSetupTimeout)
defer cancel()
ctx = timeoutCtx
}
prompt := s.promptToken + " "
setupCommands := []string{
"unset PROMPT_COMMAND >/dev/null 2>&1 || true",
fmt.Sprintf("export PS1=%s PS2='' PROMPT=%s RPROMPT='' >/dev/null 2>&1 || true", shellSingleQuote(prompt), shellSingleQuote(prompt)),
}
s.Clear()
for _, cmd := range setupCommands {
if err := s.WriteCommand(cmd); err != nil {
return fmt.Errorf("%w: %v", errStarShellPOSIXOnly, err)
}
}
probeToken := "__STARSSH_POSIX_READY__" + newNonce(6)
if err := s.WriteCommand(fmt.Sprintf("printf '%%s\\n' %s", shellSingleQuote(probeToken))); err != nil {
return fmt.Errorf("%w: %v", errStarShellPOSIXOnly, err)
}
ticker := time.NewTicker(defaultShellPollInterval)
defer ticker.Stop()
for {
outRaw, errRaw, runErr := s.readState()
if runErr != nil {
return fmt.Errorf("%w: %v", errStarShellPOSIXOnly, runErr)
}
outs := normalizeShellOutput(stripControlSequences(string(outRaw)))
if strings.Contains(outs, probeToken) {
s.Clear()
return nil
}
errs := normalizeShellOutput(stripControlSequences(string(errRaw)))
if looksLikeNonPOSIXShellError(errs) {
return fmt.Errorf("%w: %s", errStarShellPOSIXOnly, errs)
}
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
return fmt.Errorf("%w: prompt bootstrap timed out", errStarShellPOSIXOnly)
}
return ctx.Err()
case <-ticker.C:
}
}
}
func (s *StarShell) watchSession() {
if s == nil || s.Session == nil {
return
}
if err := s.Session.Wait(); err != nil && !errors.Is(err, io.EOF) {
s.setError(err)
}
}
func (s *StarShell) Close() error {
if s == nil {
return nil
}
var closeErr error
s.closeOnce.Do(func() {
if s.Session == nil {
return
}
closeErr = s.Session.Close()
})
return closeErr
}
func (s *StarShell) SwitchNoColor(run bool) {
s.rw.Lock()
defer s.rw.Unlock()
s.iscolor = run
}
func (s *StarShell) SwitchEcho(run bool) {
s.rw.Lock()
defer s.rw.Unlock()
s.isecho = run
}
func (s *StarShell) TrimColor(str string) string {
s.rw.RLock()
shouldTrim := s.iscolor
s.rw.RUnlock()
if shouldTrim {
return SedColor(str)
}
return str
}
/*
本函数控制是否在本地屏幕上打印远程Shell的输出内容[true|false]
*/
func (s *StarShell) SwitchPrint(run bool) {
s.rw.Lock()
defer s.rw.Unlock()
s.isprint = run
}
/*
本函数控制是否立即处理远程Shell输出每一行内容[true|false]
*/
func (s *StarShell) SwitchFunc(run bool) {
s.rw.Lock()
defer s.rw.Unlock()
s.isfuncs = run
}
func (s *StarShell) SetFunc(funcs func(string)) {
s.rw.Lock()
defer s.rw.Unlock()
s.funcs = funcs
}
func (s *StarShell) Clear() {
s.rw.Lock()
defer s.rw.Unlock()
s.outbyte = []byte{}
s.errbyte = []byte{}
}
func (s *StarShell) ShellClear(cmd string, sleep int) (string, string, error) {
s.Clear()
defer s.Clear()
return s.Shell(cmd, sleep)
}
func (s *StarShell) Shell(cmd string, sleep int) (string, string, error) {
s.commandMu.Lock()
defer s.commandMu.Unlock()
if err := s.WriteCommand(cmd); err != nil {
return "", "", err
}
outRaw, errRaw, runErr := s.GetResult(sleep)
if runErr != nil {
return "", "", runErr
}
outText := s.TrimColor(strings.TrimSpace(string(outRaw)))
s.rw.RLock()
echoEnabled := s.isecho
s.rw.RUnlock()
if echoEnabled {
outText = stripCommandEchoFromOutput(outText, cmd)
}
return strings.TrimSpace(outText), s.TrimColor(strings.TrimSpace(string(errRaw))), nil
}
func (s *StarShell) ShellWait(cmd string) (string, string, error) {
result, err := s.Run(context.Background(), ShellRequest{
Command: cmd,
})
if err != nil {
return "", "", err
}
return strings.TrimSpace(result.StdoutString()), strings.TrimSpace(result.StderrString()), result.CommandError()
}
func (s *StarShell) RunString(ctx context.Context, command string) (*ExecResult, error) {
return s.Run(ctx, ShellRequest{
Command: command,
})
}
func (s *StarShell) Run(ctx context.Context, req ShellRequest) (*ExecResult, error) {
if s == nil {
return nil, errors.New("shell is nil")
}
if ctx == nil {
ctx = context.Background()
}
s.commandMu.Lock()
defer s.commandMu.Unlock()
if strings.TrimSpace(req.Command) == "" {
return nil, errors.New("command is empty")
}
s.rw.RLock()
useDefault := s.UseWaitDefault
keyword := s.Keyword
waitTimeout := s.WaitTimeout
promptToken := s.promptToken
s.rw.RUnlock()
if req.UseWaitDefault != nil {
useDefault = *req.UseWaitDefault
}
if strings.TrimSpace(req.Keyword) != "" {
keyword = req.Keyword
}
if req.Timeout > 0 {
waitTimeout = req.Timeout
}
if !useDefault && keyword == "" {
return nil, errors.New("ShellRun requires UseWaitDefault=true or Keyword set")
}
if waitTimeout <= 0 {
waitTimeout = defaultShellWaitTimeout
}
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
timeoutCtx, cancel := context.WithTimeout(ctx, waitTimeout)
defer cancel()
ctx = timeoutCtx
}
startAt := time.Now()
s.Clear()
defer s.Clear()
beginToken, endToken := newCommandTokens()
markerCmd := fmt.Sprintf("__STARSSH_RC=$?; printf '%s:%%s\\n' \"$__STARSSH_RC\"", endToken)
if err := s.WriteCommand(fmt.Sprintf("printf '%s\\n'", beginToken)); err != nil {
return nil, err
}
if err := s.WriteCommand(req.Command); err != nil {
return nil, err
}
if err := s.WriteCommand(markerCmd); err != nil {
return nil, err
}
var (
outc string
errc string
exitCode int
done bool
)
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(defaultShellPollInterval):
}
outRaw, errRaw, runErr := s.readState()
if runErr != nil {
return nil, runErr
}
outs := normalizeShellOutput(stripControlSequences(string(outRaw)))
errs := normalizeShellOutput(stripControlSequences(string(errRaw)))
s.rw.RLock()
useDefault = s.UseWaitDefault
keyword = s.Keyword
s.rw.RUnlock()
if useDefault {
segment, rc, found, parseErr := extractCommandSegment(outs, beginToken, endToken)
if parseErr != nil {
return nil, parseErr
}
if found {
outc = segment
errc = errs
exitCode = rc
done = true
break
}
}
if keyword != "" {
if strings.Contains(outs, keyword) || strings.Contains(errs, keyword) {
outc = outs
errc = errs
done = true
break
}
}
}
if !done {
return nil, errors.New("failed to collect shell result")
}
outc = collectLinesForCommandOutput(outc, promptToken, beginToken, endToken)
errc = collectLinesForCommandOutput(errc, promptToken, beginToken, endToken)
outc = stripCommandEchoFromOutput(outc, req.Command)
stdoutText := strings.TrimSpace(outc)
stderrText := strings.TrimSpace(errc)
result := &ExecResult{
Command: req.Command,
Stdout: []byte(stdoutText),
Stderr: []byte(stderrText),
Combined: combineCommandOutput(stdoutText, stderrText),
Duration: time.Since(startAt),
}
if useDefault {
result.ExitCode = exitCode
}
return result, nil
}
func extractCommandSegment(stdout string, beginToken string, endToken string) (string, int, bool, error) {
lines := strings.Split(stdout, "\n")
beginLine := -1
for i, line := range lines {
if strings.TrimSpace(line) == beginToken {
beginLine = i
break
}
}
if beginLine < 0 {
return "", 0, false, nil
}
segment := strings.Join(lines[beginLine+1:], "\n")
before, rc, found, err := splitByEndToken(segment, endToken)
if err != nil {
return "", 0, false, err
}
if !found {
return "", 0, false, nil
}
return strings.TrimSpace(before), rc, true, nil
}
func (s *StarShell) GetResult(sleep int) ([]byte, []byte, error) {
if sleep > 0 {
time.Sleep(time.Millisecond * time.Duration(sleep))
}
return s.readState()
}
func (s *StarShell) WriteCommand(cmd string) error {
return s.Write([]byte(cmd + "\n"))
}
func (s *StarShell) Write(bstr []byte) error {
if s == nil {
return errors.New("shell is nil")
}
if s.in == nil {
return errors.New("shell stdin is not initialized")
}
if _, _, runErr := s.readState(); runErr != nil {
return runErr
}
s.writeMu.Lock()
defer s.writeMu.Unlock()
_, err := s.in.Write(bstr)
return err
}
func (s *StarShell) gohub() {
if s.er != nil {
go s.streamPump(s.er, true)
}
if s.out != nil {
go s.streamPump(s.out, false)
}
}
func (s *StarShell) streamPump(reader *bufio.Reader, isStderr bool) {
var cache []byte
for {
read, err := reader.ReadByte()
if err == io.EOF {
return
}
if err != nil {
s.setError(err)
return
}
s.rw.Lock()
if isStderr {
s.errbyte = append(s.errbyte, read)
} else {
s.outbyte = append(s.outbyte, read)
}
printEnabled := s.isprint
funcEnabled := s.isfuncs && s.funcs != nil
lineHandler := s.funcs
trimColor := s.iscolor
s.rw.Unlock()
if printEnabled {
fmt.Print(string([]byte{read}))
}
cache = append(cache, read)
if read == '\n' {
if funcEnabled {
line := strings.TrimSpace(string(cache))
if trimColor {
line = SedColor(line)
}
go lineHandler(line)
}
cache = cache[:0]
}
}
}
func (s *StarShell) setError(err error) {
if err == nil || errors.Is(err, io.EOF) {
return
}
s.rw.Lock()
defer s.rw.Unlock()
if s.errors == nil {
s.errors = err
}
}
func (s *StarShell) readState() ([]byte, []byte, error) {
s.rw.RLock()
defer s.rw.RUnlock()
outCopy := make([]byte, len(s.outbyte))
copy(outCopy, s.outbyte)
errCopy := make([]byte, len(s.errbyte))
copy(errCopy, s.errbyte)
if s.errors != nil {
return outCopy, errCopy, s.errors
}
return outCopy, errCopy, nil
}
func stripCommandEchoFromOutput(output string, cmd string) string {
if output == "" || cmd == "" {
return strings.TrimSpace(output)
}
lines := strings.Split(output, "\n")
cmdLines := strings.Split(cmd, "\n")
for _, cmdLine := range cmdLines {
trimmedCmd := strings.TrimSpace(cmdLine)
if trimmedCmd == "" {
continue
}
for i, line := range lines {
if strings.TrimSpace(line) == trimmedCmd {
lines = append(lines[:i], lines[i+1:]...)
break
}
}
}
return strings.TrimSpace(strings.Join(lines, "\n"))
}
func looksLikeNonPOSIXShellError(output string) bool {
if strings.TrimSpace(output) == "" {
return false
}
lower := strings.ToLower(output)
indicators := []string{
"is not recognized as an internal or external command",
"the term ",
"command not found",
"unknown command",
"not found",
"not recognized",
}
for _, indicator := range indicators {
if strings.Contains(lower, indicator) {
return true
}
}
return false
}
func (s *StarShell) GetUid() string {
res, _, _ := s.ShellWait("id -u")
return strings.TrimSpace(res)
}
func (s *StarShell) GetGid() string {
res, _, _ := s.ShellWait("id -g")
return strings.TrimSpace(res)
}
func (s *StarShell) GetUser() string {
res, _, _ := s.ShellWait("id -un")
return strings.TrimSpace(res)
}
func (s *StarShell) GetGroup() string {
res, _, _ := s.ShellWait("id -gn")
return strings.TrimSpace(res)
}

636
ssh.go
View File

@ -1,636 +0,0 @@
package starssh
import (
"bufio"
"encoding/base64"
"errors"
"fmt"
"golang.org/x/crypto/ssh"
"io"
"io/ioutil"
"net"
"regexp"
"strings"
"sync"
"time"
)
type StarSSH struct {
Client *ssh.Client
PublicKey ssh.PublicKey
PubkeyBase64 string
Hostname string
RemoteAddr net.Addr
Banner string
LoginInfo LoginInput
online bool
}
type LoginInput struct {
KeyExchanges []string
Ciphers []string
MACs []string
User string
Password string
Prikey string
PrikeyPwd string
Addr string
Port int
Timeout time.Duration
HostKeyCallback func(string, net.Addr, ssh.PublicKey) error
BannerCallback func(string) error
}
type StarShell struct {
Keyword string
UseWaitDefault bool
Session *ssh.Session
in io.Writer
out *bufio.Reader
er *bufio.Reader
outbyte []byte
errbyte []byte
lastout int64
errors error
isprint bool
isfuncs bool
iscolor bool
isecho bool
rw sync.RWMutex
funcs func(string)
}
func Login(info LoginInput) (*StarSSH, error) {
var (
auth []ssh.AuthMethod
clientConfig *ssh.ClientConfig
config ssh.Config
err error
)
sshInfo := new(StarSSH)
// get auth method
auth = make([]ssh.AuthMethod, 0)
if info.Prikey == "" {
keyboardInteractiveChallenge := func(
user,
instruction string,
questions []string,
echos []bool,
) (answers []string, err error) {
if len(questions) == 0 {
return []string{}, nil
}
return []string{info.Password}, nil
}
auth = append(auth, ssh.Password(info.Password))
auth = append(auth, ssh.KeyboardInteractive(keyboardInteractiveChallenge))
} else {
pemBytes := []byte(info.Prikey)
var signer ssh.Signer
if info.PrikeyPwd == "" {
signer, err = ssh.ParsePrivateKey(pemBytes)
} else {
signer, err = ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(info.PrikeyPwd))
}
if err != nil {
return nil, err
}
auth = append(auth, ssh.PublicKeys(signer))
}
if len(info.Ciphers) == 0 {
config = ssh.Config{
Ciphers: []string{"aes128-ctr", "aes192-ctr", "aes256-ctr", "aes128-gcm@openssh.com", "arcfour256", "arcfour128", "aes128-cbc", "3des-cbc", "aes192-cbc", "aes256-cbc", "chacha20-poly1305@openssh.com"},
}
} else {
config = ssh.Config{
Ciphers: info.Ciphers,
}
}
if len(info.MACs) != 0 {
config.MACs = info.MACs
}
if len(info.KeyExchanges) != 0 {
config.KeyExchanges = info.KeyExchanges
}
if info.Timeout == 0 {
info.Timeout = time.Second * 5
}
hostKeycbfunc := func(hostname string, remote net.Addr, key ssh.PublicKey) error {
sshInfo.PublicKey = key
sshInfo.RemoteAddr = remote
sshInfo.Hostname = hostname
if info.HostKeyCallback != nil {
return info.HostKeyCallback(hostname, remote, key)
}
return nil
}
bannercbfunc := func(banner string) error {
sshInfo.Banner = banner
if info.BannerCallback != nil {
return info.BannerCallback(banner)
}
return nil
}
clientConfig = &ssh.ClientConfig{
User: info.User,
Auth: auth,
Timeout: info.Timeout,
Config: config,
HostKeyCallback: hostKeycbfunc,
BannerCallback: bannercbfunc,
}
// connet to ssh
sshInfo.LoginInfo = info
sshInfo.Client, err = ssh.Dial("tcp", fmt.Sprintf("%s:%d", info.Addr, info.Port), clientConfig)
if err == nil && sshInfo.PublicKey != nil {
sshInfo.online = true
sshInfo.PubkeyBase64 = base64.StdEncoding.EncodeToString(sshInfo.PublicKey.Marshal())
}
return sshInfo, err
}
func LoginSimple(host string, user string, passwd string, prikeyPath string, port int, timeout time.Duration) (*StarSSH, error) {
var info = LoginInput{
Addr: host,
Port: port,
Timeout: timeout,
User: user,
}
if prikeyPath != "" {
prikey, err := ioutil.ReadFile(prikeyPath)
if err != nil {
return nil, err
}
info.Prikey = string(prikey)
if passwd != "" {
info.PrikeyPwd = passwd
}
} else {
info.Password = passwd
}
return Login(info)
}
func (s *StarShell) ShellWait(cmd string) (string, string, error) {
var outc, errc string = " ", " "
s.Clear()
defer s.Clear()
echo := "echo b7Y85R56TUY6R5UTb612"
err := s.WriteCommand(cmd)
if err != nil {
return "", "", err
}
time.Sleep(time.Millisecond * 20)
err = s.WriteCommand(echo)
if err != nil {
return "", "", err
}
for {
time.Sleep(time.Millisecond * 120)
outs := string(s.outbyte)
errs := string(s.errbyte)
outs = strings.TrimSpace(strings.ReplaceAll(outs, "\r\n", "\n"))
errs = strings.TrimSpace(strings.ReplaceAll(errs, "\r\n", "\n"))
if len(outs) >= len(cmd+"\n"+echo) && outs[0:len(cmd+"\n"+echo)] == cmd+"\n"+echo {
outs = outs[len(cmd+"\n"+echo):]
} else if len(outs) >= len(cmd) && outs[0:len(cmd)] == cmd {
outs = outs[len(cmd):]
}
if len(errs) >= len(cmd) && errs[0:len(cmd)] == cmd {
errs = errs[len(cmd):]
}
if s.UseWaitDefault {
if strings.Index(string(outs), "b7Y85R56TUY6R5UTb612") >= 0 {
list := strings.Split(string(outs), "\n")
for _, v := range list {
if strings.Index(v, "b7Y85R56TUY6R5UTb612") < 0 {
outc += v + "\n"
}
}
break
}
if strings.Index(string(errs), "b7Y85R56TUY6R5UTb612") >= 0 {
list := strings.Split(string(errs), "\n")
for _, v := range list {
if strings.Index(v, "b7Y85R56TUY6R5UTb612") < 0 {
errc += v + "\n"
}
}
break
}
}
if s.Keyword != "" {
if strings.Index(string(outs), s.Keyword) >= 0 {
list := strings.Split(string(outs), "\n")
for _, v := range list {
if strings.Index(v, s.Keyword) < 0 && strings.Index(v, "b7Y85R56TUY6R5UTb612") < 0 {
outc += v + "\n"
}
}
break
}
if strings.Index(string(errs), s.Keyword) >= 0 {
list := strings.Split(string(errs), "\n")
for _, v := range list {
if strings.Index(v, s.Keyword) < 0 && strings.Index(v, "b7Y85R56TUY6R5UTb612") < 0 {
errc += v + "\n"
}
}
break
}
}
}
return s.TrimColor(strings.TrimSpace(outc)), s.TrimColor(strings.TrimSpace(errc)), err
}
func (s *StarShell) Close() error {
return s.Session.Close()
}
func (s *StarShell) SwitchNoColor(is bool) {
s.iscolor = is
}
func (s *StarShell) SwitchEcho(is bool) {
s.isecho = is
}
func (s *StarShell) TrimColor(str string) string {
if s.iscolor {
return SedColor(str)
}
return str
}
/*
本函数控制是否在本地屏幕上打印远程Shell的输出内容[true|false]
*/
func (s *StarShell) SwitchPrint(run bool) {
s.isprint = run
}
/*
本函数控制是否立即处理远程Shell输出每一行内容[true|false]
*/
func (s *StarShell) SwitchFunc(run bool) {
s.isfuncs = run
}
func (s *StarShell) SetFunc(funcs func(string)) {
s.funcs = funcs
}
func (s *StarShell) Clear() {
defer s.rw.Unlock()
s.rw.Lock()
s.outbyte = []byte{}
s.errbyte = []byte{}
time.Sleep(time.Millisecond * 15)
}
func (s *StarShell) ShellClear(cmd string, sleep int) (string, string, error) {
defer s.Clear()
s.Clear()
return s.Shell(cmd, sleep)
}
func (s *StarShell) Shell(cmd string, sleep int) (string, string, error) {
if err := s.WriteCommand(cmd); err != nil {
return "", "", err
}
tmp1, tmp2, err := s.GetResult(sleep)
tmps := s.TrimColor(strings.TrimSpace(string(tmp1)))
if s.isecho {
n := len(strings.Split(cmd, "\n"))
if n == 1 {
list := strings.SplitN(tmps, "\n", 2)
if len(list) == 2 {
tmps = list[1]
}
} else {
list := strings.Split(tmps, "\n")
cmds := strings.Split(cmd, "\n")
for _, v := range cmds {
for k, v2 := range list {
if strings.TrimSpace(v2) == strings.TrimSpace(v) {
list[k] = ""
break
}
}
}
tmps = ""
for _, v := range list {
if v != "" {
tmps += v + "\n"
}
}
tmps = tmps[0 : len(tmps)-1]
}
}
return tmps, s.TrimColor(strings.TrimSpace(string(tmp2))), err
}
func (s *StarShell) GetResult(sleep int) ([]byte, []byte, error) {
if s.errors != nil {
s.Session.Close()
return s.outbyte, s.errbyte, s.errors
}
if sleep > 0 {
time.Sleep(time.Millisecond * time.Duration(sleep))
}
return s.outbyte, s.errbyte, nil
}
func (s *StarShell) WriteCommand(cmd string) error {
return s.Write([]byte(cmd + "\n"))
}
func (s *StarShell) Write(bstr []byte) error {
if s.errors != nil {
s.Session.Close()
return s.errors
}
_, err := s.in.Write(bstr)
return err
}
func (s *StarShell) gohub() {
go func() {
var cache []byte
for {
read, err := s.er.ReadByte()
if err != nil {
s.errors = err
return
}
s.errbyte = append(s.errbyte, read)
if s.isprint {
fmt.Print(string([]byte{read}))
}
cache = append(cache, read)
if read == '\n' {
if s.isfuncs {
go s.funcs(s.TrimColor(strings.TrimSpace(string(cache))))
cache = []byte{}
}
}
}
}()
var cache []byte
for {
read, err := s.out.ReadByte()
if err != nil {
s.errors = err
return
}
s.rw.Lock()
s.outbyte = append(s.outbyte, read)
cache = append(cache, read)
s.rw.Unlock()
if read == '\n' {
if s.isfuncs {
go s.funcs(strings.TrimSpace(string(cache)))
cache = []byte{}
}
}
if s.isprint {
fmt.Print(string([]byte{read}))
}
}
}
func (s *StarShell) GetUid() string {
res, _, _ := s.ShellWait(`id | grep -oP "(?<=uid\=)\d+"`)
return strings.TrimSpace(res)
}
func (s *StarShell) GetGid() string {
res, _, _ := s.ShellWait(`id | grep -oP "(?<=gid\=)\d+"`)
return strings.TrimSpace(res)
}
func (s *StarShell) GetUser() string {
res, _, _ := s.ShellWait(`id | grep -oP "(?<=\().*?(?=\))" | head -n 1`)
return strings.TrimSpace(res)
}
func (s *StarShell) GetGroup() string {
res, _, _ := s.ShellWait(`id | grep -oP "(?<=\().*?(?=\))" | head -n 2 | tail -n 1`)
return strings.TrimSpace(res)
}
func (s *StarSSH) NewShell() (shell *StarShell, err error) {
shell = new(StarShell)
shell.Session, err = s.NewSession()
if err != nil {
return
}
shell.in, _ = shell.Session.StdinPipe()
tmp, _ := shell.Session.StdoutPipe()
shell.out = bufio.NewReader(tmp)
tmp, _ = shell.Session.StderrPipe()
shell.er = bufio.NewReader(tmp)
err = shell.Session.Shell()
shell.isecho = true
go shell.Session.Wait()
shell.UseWaitDefault = true
shell.WriteCommand("bash")
time.Sleep(500 * time.Millisecond)
shell.WriteCommand("export PS1= ")
shell.WriteCommand("export PS2= ")
go shell.gohub()
time.Sleep(500 * time.Millisecond)
shell.Clear()
shell.Clear()
return
}
func (s *StarSSH) Close() error {
if s.online {
return s.Client.Close()
}
return nil
}
func (s *StarSSH) NewSession() (*ssh.Session, error) {
return NewSession(s.Client)
}
func (s *StarSSH) ShellOne(cmd string) (string, error) {
newsess, err := s.NewSession()
if err != nil {
return "", err
}
data, err := newsess.CombinedOutput(cmd)
newsess.Close()
return strings.TrimSpace(string(data)), err
}
func (s *StarSSH) Exists(filepath string) bool {
res, _ := s.ShellOne(`echo 1 && [ ! -e "` + filepath + `" ] && echo 2`)
if res == "1" {
return true
} else {
return false
}
}
func (s *StarSSH) GetUid() string {
res, _ := s.ShellOne(`id | grep -oP "(?<=uid\=)\d+"`)
return strings.TrimSpace(res)
}
func (s *StarSSH) GetGid() string {
res, _ := s.ShellOne(`id | grep -oP "(?<=gid\=)\d+"`)
return strings.TrimSpace(res)
}
func (s *StarSSH) GetUser() string {
res, _ := s.ShellOne(`id | grep -oP "(?<=\().*?(?=\))" | head -n 1`)
return strings.TrimSpace(res)
}
func (s *StarSSH) GetGroup() string {
res, _ := s.ShellOne(`id | grep -oP "(?<=\().*?(?=\))" | head -n 2 | tail -n 1`)
return strings.TrimSpace(res)
}
func (s *StarSSH) IsFile(filepath string) bool {
res, _ := s.ShellOne(`echo 1 && [ ! -f "` + filepath + `" ] && echo 2`)
if res == "1" {
return true
} else {
return false
}
}
func (s *StarSSH) IsFolder(filepath string) bool {
res, _ := s.ShellOne(`echo 1 && [ ! -d "` + filepath + `" ] && echo 2`)
if res == "1" {
return true
} else {
return false
}
}
func (s *StarSSH) ShellOneShowScreen(cmd string) (string, error) {
newsess, err := s.NewSession()
if err != nil {
return "", err
}
var bytes, errbytes []byte
tmp, _ := newsess.StdoutPipe()
reader := bufio.NewReader(tmp)
tmp, _ = newsess.StderrPipe()
errder := bufio.NewReader(tmp)
err = newsess.Start(cmd)
if err != nil {
return "", err
}
c := make(chan int, 1)
go newsess.Wait()
go func() {
for {
byt, err := reader.ReadByte()
if err != nil {
break
}
fmt.Print(string([]byte{byt}))
bytes = append(bytes, byt)
}
c <- 1
}()
for {
byt, err := errder.ReadByte()
if err != nil {
break
}
fmt.Print(string([]byte{byt}))
errbytes = append(errbytes, byt)
}
_ = <-c
newsess.Close()
if len(errbytes) != 0 {
err = errors.New(strings.TrimSpace(string(errbytes)))
} else {
err = nil
}
return strings.TrimSpace(string(bytes)), err
}
func (s *StarSSH) ShellOneToFunc(cmd string, callback func(string)) (string, error) {
newsess, err := s.NewSession()
if err != nil {
return "", err
}
var bytes, errbytes []byte
tmp, _ := newsess.StdoutPipe()
reader := bufio.NewReader(tmp)
tmp, _ = newsess.StderrPipe()
errder := bufio.NewReader(tmp)
err = newsess.Start(cmd)
if err != nil {
return "", err
}
c := make(chan int, 1)
go newsess.Wait()
go func() {
for {
byt, err := reader.ReadByte()
if err != nil {
break
}
callback(string([]byte{byt}))
bytes = append(bytes, byt)
}
c <- 1
}()
for {
byt, err := errder.ReadByte()
if err != nil {
break
}
callback(string([]byte{byt}))
errbytes = append(errbytes, byt)
}
_ = <-c
newsess.Close()
if len(errbytes) != 0 {
err = errors.New(strings.TrimSpace(string(errbytes)))
} else {
err = nil
}
return strings.TrimSpace(string(bytes)), err
}
func NewTransferSession(client *ssh.Client) (*ssh.Session, error) {
session, err := client.NewSession()
return session, err
}
func NewSession(client *ssh.Client) (*ssh.Session, error) {
var session *ssh.Session
var err error
// create session
if session, err = client.NewSession(); err != nil {
return nil, err
}
modes := ssh.TerminalModes{
ssh.ECHO: 1, // 还是要强制开启
//ssh.IGNCR: 0,
ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
}
if err := session.RequestPty("xterm", 500, 250, modes); err != nil {
return nil, err
}
return session, nil
}
func SedColor(str string) string {
reg := regexp.MustCompile(`\x1B\[([0-9]{1,2}(;[0-9]{1,2})?)?[m|K]`)
//fmt.Println("regexp:", reg.Match([]byte(str)))
return string(reg.ReplaceAll([]byte(str), []byte("")))
}

21
sshagent_unix.go Normal file
View File

@ -0,0 +1,21 @@
//go:build !windows
package starssh
import (
"net"
"os"
"strings"
"time"
)
func dialSSHAgent(timeout time.Duration) (net.Conn, error) {
agentSock := strings.TrimSpace(os.Getenv("SSH_AUTH_SOCK"))
if agentSock == "" {
return nil, errSSHAgentUnavailable
}
if timeout > 0 {
return net.DialTimeout("unix", agentSock, timeout)
}
return net.Dial("unix", agentSock)
}

70
sshagent_windows.go Normal file
View File

@ -0,0 +1,70 @@
//go:build windows
package starssh
import (
"context"
"errors"
"net"
"os"
"strings"
"time"
"github.com/Microsoft/go-winio"
"golang.org/x/sys/windows"
)
const defaultWindowsSSHAgentPipe = `\\.\pipe\openssh-ssh-agent`
func dialSSHAgent(timeout time.Duration) (net.Conn, error) {
agentSock := strings.TrimSpace(os.Getenv("SSH_AUTH_SOCK"))
if agentSock != "" {
return dialWindowsSSHAgentEndpoint(agentSock, timeout)
}
return dialWindowsNamedPipe(defaultWindowsSSHAgentPipe, timeout, true)
}
func dialWindowsSSHAgentEndpoint(endpoint string, timeout time.Duration) (net.Conn, error) {
if pipePath, ok := normalizeWindowsSSHAgentPipe(endpoint); ok {
return dialWindowsNamedPipe(pipePath, timeout, false)
}
if timeout > 0 {
return net.DialTimeout("unix", endpoint, timeout)
}
return net.Dial("unix", endpoint)
}
func dialWindowsNamedPipe(path string, timeout time.Duration, unavailableOnNotFound bool) (net.Conn, error) {
ctx := context.Background()
cancel := func() {}
if timeout > 0 {
ctx, cancel = context.WithTimeout(ctx, timeout)
}
defer cancel()
conn, err := winio.DialPipeContext(ctx, path)
if err != nil && unavailableOnNotFound && isWindowsPipeUnavailable(err) {
return nil, errSSHAgentUnavailable
}
return conn, err
}
func normalizeWindowsSSHAgentPipe(endpoint string) (string, bool) {
trimmed := strings.TrimSpace(endpoint)
if trimmed == "" {
return "", false
}
normalized := trimmed
if strings.HasPrefix(normalized, "//./pipe/") {
normalized = `\\.\pipe\` + strings.TrimPrefix(normalized, "//./pipe/")
}
if strings.HasPrefix(normalized, `\\.\pipe\`) {
return normalized, true
}
return "", false
}
func isWindowsPipeUnavailable(err error) bool {
return errors.Is(err, windows.ERROR_FILE_NOT_FOUND) || errors.Is(err, windows.ERROR_PATH_NOT_FOUND)
}

106
state.go Normal file
View File

@ -0,0 +1,106 @@
package starssh
import (
"errors"
"golang.org/x/crypto/ssh"
)
type sshClientRequester interface {
SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error)
Close() error
}
var closeSSHClient = func(client sshClientRequester) error {
if client == nil {
return nil
}
return client.Close()
}
func (s *StarSSH) snapshotSSHClient() *ssh.Client {
if s == nil {
return nil
}
s.stateMu.RLock()
defer s.stateMu.RUnlock()
return s.Client
}
func (s *StarSSH) requireSSHClient() (*ssh.Client, error) {
client := s.snapshotSSHClient()
if client == nil {
return nil, errors.New("ssh client is nil")
}
return client, nil
}
func (s *StarSSH) setTransport(client *ssh.Client, upstream *StarSSH) {
if s == nil {
return
}
s.stateMu.Lock()
defer s.stateMu.Unlock()
s.Client = client
s.upstream = upstream
s.online = client != nil
}
func (s *StarSSH) detachTransport() (*ssh.Client, *StarSSH) {
if s == nil {
return nil, nil
}
s.stateMu.Lock()
defer s.stateMu.Unlock()
client := s.Client
upstream := s.upstream
s.Client = nil
s.upstream = nil
s.online = false
return client, upstream
}
func (s *StarSSH) takeKeepaliveHandles() (chan struct{}, chan struct{}) {
if s == nil {
return nil, nil
}
s.keepaliveMu.Lock()
defer s.keepaliveMu.Unlock()
stop := s.keepaliveStop
done := s.keepaliveDone
s.keepaliveStop = nil
s.keepaliveDone = nil
return stop, done
}
func (s *StarSSH) closeTransport(waitKeepalive bool) error {
if s == nil {
return nil
}
_ = s.closeReusableSFTPClient()
client, upstream := s.detachTransport()
stop, done := s.takeKeepaliveHandles()
if stop != nil {
close(stop)
}
var closeErr error
if client != nil {
closeErr = normalizeAlreadyClosedError(closeSSHClient(client))
}
if waitKeepalive && done != nil {
<-done
}
if upstreamErr := closeUpstream(upstream); closeErr == nil {
closeErr = upstreamErr
}
return closeErr
}

431
terminal.go Normal file
View File

@ -0,0 +1,431 @@
package starssh
import (
"context"
"errors"
"io"
"sync"
"golang.org/x/crypto/ssh"
)
func (s *StarSSH) NewTerminal(config *TerminalConfig) (*TerminalSession, error) {
session, err := s.NewPTYSession(config)
if err != nil {
return nil, err
}
stdin, err := session.StdinPipe()
if err != nil {
_ = session.Close()
return nil, err
}
stdout, err := session.StdoutPipe()
if err != nil {
_ = session.Close()
return nil, err
}
stderr, err := session.StderrPipe()
if err != nil {
_ = session.Close()
return nil, err
}
if err := session.Shell(); err != nil {
_ = session.Close()
return nil, err
}
return &TerminalSession{
Session: session,
stdin: stdin,
stdout: stdout,
stderr: stderr,
runDone: make(chan struct{}),
waitDone: make(chan struct{}),
}, nil
}
func (t *TerminalSession) AttachIO(stdin io.Reader, stdout io.Writer, stderr io.Writer) {
if t == nil {
return
}
t.attachMu.Lock()
defer t.attachMu.Unlock()
t.in = stdin
t.out = stdout
t.errOut = stderr
}
func (t *TerminalSession) StdinWriter() io.Writer {
if t == nil {
return nil
}
return t.stdin
}
func (t *TerminalSession) StdoutReader() io.Reader {
if t == nil {
return nil
}
return t.stdout
}
func (t *TerminalSession) StderrReader() io.Reader {
if t == nil {
return nil
}
return t.stderr
}
func (t *TerminalSession) Write(data []byte) (int, error) {
if t == nil || t.stdin == nil {
return 0, errors.New("terminal stdin is not initialized")
}
return t.stdin.Write(data)
}
func (t *TerminalSession) SendControl(control TerminalControl) error {
_, err := t.Write([]byte{byte(control)})
return err
}
func (t *TerminalSession) Interrupt() error {
return t.SendControl(TerminalControlInterrupt)
}
func (t *TerminalSession) Signal(sig ssh.Signal) error {
if t == nil || t.Session == nil {
return errors.New("terminal session is not initialized")
}
if sig == "" {
return errors.New("signal is empty")
}
return t.Session.Signal(sig)
}
func (t *TerminalSession) Run(ctx context.Context) error {
if t == nil {
return errors.New("terminal session is nil")
}
if ctx == nil {
ctx = context.Background()
}
t.runOnce.Do(func() {
t.runErr = t.run(ctx)
close(t.runDone)
})
<-t.runDone
return t.runErr
}
func (t *TerminalSession) run(ctx context.Context) error {
if t.Session == nil {
return errors.New("terminal session is not initialized")
}
t.attachMu.RLock()
in := t.in
out := t.out
errOut := t.errOut
t.attachMu.RUnlock()
if out == nil {
out = io.Discard
}
if errOut == nil {
errOut = out
}
inputReader, cancelInput, inputCancelable, err := prepareTerminalInputReader(in)
if err != nil {
return err
}
defer cancelInput()
var copyWG sync.WaitGroup
doneCopy := make(chan struct{})
copyWG.Add(2)
go func() {
defer copyWG.Done()
if t.stdout != nil {
_, _ = io.Copy(out, t.stdout)
}
}()
go func() {
defer copyWG.Done()
if t.stderr != nil {
_, _ = io.Copy(errOut, t.stderr)
}
}()
go func() {
copyWG.Wait()
close(doneCopy)
}()
var doneInput chan struct{}
if inputReader != nil && t.stdin != nil {
doneInput = make(chan struct{})
go func() {
defer close(doneInput)
_, _ = io.Copy(t.stdin, inputReader)
_ = t.stdin.Close()
}()
}
waitInputPump := func() {
if doneInput == nil {
return
}
select {
case <-doneInput:
return
default:
}
if inputCancelable {
cancelInput()
<-doneInput
}
}
type waitResult struct {
info TerminalExitInfo
err error
}
waitCh := make(chan waitResult, 1)
go func() {
info, err := t.WaitResult()
waitCh <- waitResult{info: info, err: err}
}()
select {
case result := <-waitCh:
waitInputPump()
<-doneCopy
if result.err != nil {
return result.err
}
return result.info.CommandError()
case <-ctx.Done():
t.markCloseReason(terminalCloseReasonFromErr(ctx.Err()), ctx.Err())
cancelInput()
_ = t.Close()
<-waitCh
waitInputPump()
<-doneCopy
return ctx.Err()
}
}
func (t *TerminalSession) Wait() error {
info, err := t.WaitResult()
if err != nil {
return err
}
return info.CommandError()
}
func (t *TerminalSession) WaitResult() (TerminalExitInfo, error) {
waitErr := t.waitRaw()
info, closeErr := t.snapshotExitState()
if closeErr != nil {
return info, closeErr
}
if waitErr == nil {
return info, nil
}
var exitErr *ssh.ExitError
if errors.As(waitErr, &exitErr) {
return info, nil
}
if normalizeAlreadyClosedError(waitErr) == nil || info.Reason == TerminalCloseReasonClosed {
return info, nil
}
return info, waitErr
}
func (t *TerminalSession) ExitInfo() TerminalExitInfo {
if t == nil {
return TerminalExitInfo{}
}
t.stateMu.RLock()
defer t.stateMu.RUnlock()
return t.exitInfo
}
func (info TerminalExitInfo) Success() bool {
return info.Reason == TerminalCloseReasonExit && info.ExitCode == 0 && info.ExitSignal == ""
}
func (info TerminalExitInfo) CommandError() error {
if info.Reason != TerminalCloseReasonExit && info.Reason != TerminalCloseReasonSignal {
return nil
}
if info.ExitCode == 0 && info.ExitSignal == "" {
return nil
}
return &ExecExitError{
Status: info.ExitCode,
Signal: info.ExitSignal,
Message: info.ExitMessage,
}
}
func (t *TerminalSession) Resize(columns int, rows int) error {
if t == nil || t.Session == nil {
return errors.New("terminal session is not initialized")
}
if columns <= 0 || rows <= 0 {
return errors.New("columns and rows must be > 0")
}
return t.Session.WindowChange(rows, columns)
}
func (t *TerminalSession) Close() error {
if t == nil {
return nil
}
var closeErr error
t.closeOnce.Do(func() {
if t.stdin != nil {
_ = t.stdin.Close()
}
if t.Session != nil {
closeErr = normalizeAlreadyClosedError(t.Session.Close())
}
})
if closeErr != nil {
t.markCloseReason(TerminalCloseReasonTransportError, closeErr)
}
return closeErr
}
func (t *TerminalSession) waitRaw() error {
if t == nil || t.Session == nil {
return errors.New("terminal session is not initialized")
}
t.waitOnce.Do(func() {
go func() {
waitErr := t.Session.Wait()
t.setWaitResult(waitErr)
close(t.waitDone)
}()
})
<-t.waitDone
t.stateMu.RLock()
defer t.stateMu.RUnlock()
return t.waitErr
}
func (t *TerminalSession) setWaitResult(waitErr error) {
if t == nil {
return
}
t.stateMu.Lock()
defer t.stateMu.Unlock()
t.waitErr = waitErr
t.exitInfo = buildTerminalExitInfo(waitErr, t.closeReason)
}
func (t *TerminalSession) markCloseReason(reason TerminalCloseReason, err error) {
if t == nil || reason == TerminalCloseReasonUnknown {
return
}
t.stateMu.Lock()
defer t.stateMu.Unlock()
if terminalCloseReasonPriority(reason) >= terminalCloseReasonPriority(t.closeReason) {
t.closeReason = reason
}
if err != nil && t.closeErr == nil {
t.closeErr = err
}
}
func (t *TerminalSession) snapshotExitState() (TerminalExitInfo, error) {
if t == nil {
return TerminalExitInfo{}, nil
}
t.stateMu.RLock()
defer t.stateMu.RUnlock()
return t.exitInfo, t.closeErr
}
func buildTerminalExitInfo(waitErr error, overrideReason TerminalCloseReason) TerminalExitInfo {
info := TerminalExitInfo{}
if waitErr == nil {
info.Reason = TerminalCloseReasonExit
} else {
var exitErr *ssh.ExitError
switch {
case errors.As(waitErr, &exitErr):
info.ExitCode = exitErr.ExitStatus()
info.ExitSignal = exitErr.Signal()
info.ExitMessage = exitErr.Msg()
if info.ExitSignal != "" {
info.Reason = TerminalCloseReasonSignal
} else {
info.Reason = TerminalCloseReasonExit
}
case normalizeAlreadyClosedError(waitErr) == nil:
info.Reason = TerminalCloseReasonClosed
case errors.Is(waitErr, context.Canceled):
info.Reason = TerminalCloseReasonContextCanceled
case errors.Is(waitErr, context.DeadlineExceeded):
info.Reason = TerminalCloseReasonDeadlineExceeded
default:
info.Reason = TerminalCloseReasonTransportError
}
}
if overrideReason != TerminalCloseReasonUnknown &&
terminalCloseReasonPriority(overrideReason) >= terminalCloseReasonPriority(info.Reason) {
info.Reason = overrideReason
}
return info
}
func terminalCloseReasonFromErr(err error) TerminalCloseReason {
switch {
case errors.Is(err, context.DeadlineExceeded):
return TerminalCloseReasonDeadlineExceeded
case errors.Is(err, context.Canceled):
return TerminalCloseReasonContextCanceled
case err != nil:
return TerminalCloseReasonTransportError
default:
return TerminalCloseReasonUnknown
}
}
func terminalCloseReasonPriority(reason TerminalCloseReason) int {
switch reason {
case TerminalCloseReasonContextCanceled, TerminalCloseReasonDeadlineExceeded:
return 30
case TerminalCloseReasonTransportError:
return 20
case TerminalCloseReasonClosed:
return 10
case TerminalCloseReasonSignal, TerminalCloseReasonExit:
return 5
default:
return 0
}
}

225
terminal_input.go Normal file
View File

@ -0,0 +1,225 @@
package starssh
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"os"
"reflect"
"sync"
"unsafe"
)
// TerminalInputSourceProvider lets wrapper readers expose a closer-friendly source reader.
// Implementations that buffer data should return a source that already includes any prefetched bytes.
type TerminalInputSourceProvider interface {
TerminalInputSource() io.Reader
}
// TerminalInputCanceler lets wrapper readers expose an explicit cancellation hook.
// It is useful for line editors or custom buffered readers that cannot safely expose a raw io.ReadCloser.
type TerminalInputCanceler interface {
TerminalInputCancel() error
}
// TerminalInputAdapter adapts wrapper readers into a cancelable terminal input source.
// Reader is what TerminalSession consumes, Source is the closer-friendly underlying reader when available.
type TerminalInputAdapter struct {
Reader io.Reader
Source io.Reader
Cancel func() error
}
func (a TerminalInputAdapter) Read(p []byte) (int, error) {
if a.Reader == nil {
return 0, io.EOF
}
return a.Reader.Read(p)
}
func (a TerminalInputAdapter) TerminalInputSource() io.Reader {
if a.Source != nil {
return a.Source
}
return a.Reader
}
func (a TerminalInputAdapter) TerminalInputCancel() error {
if a.Cancel != nil {
return a.Cancel()
}
if closer, ok := a.Source.(io.Closer); ok && closer != nil {
return closer.Close()
}
if closer, ok := a.Reader.(io.Closer); ok && closer != nil {
return closer.Close()
}
return nil
}
func prepareTerminalInputReader(in io.Reader) (io.Reader, func(), bool, error) {
if in == nil {
return nil, func() {}, false, nil
}
var cancelOnce sync.Once
wrapCancel := func(fn func()) func() {
return func() {
cancelOnce.Do(fn)
}
}
if provider, ok := in.(TerminalInputSourceProvider); ok {
source := provider.TerminalInputSource()
if source == nil || sameReader(source, in) {
return prepareDirectTerminalInputReader(in, wrapCancel)
}
prepared, cancel, cancelable, err := prepareTerminalInputReader(source)
if err != nil {
return nil, nil, false, err
}
if canceler, ok := in.(TerminalInputCanceler); ok {
return prepared, wrapCancel(func() {
cancel()
_ = canceler.TerminalInputCancel()
}), true, nil
}
return prepared, cancel, cancelable, nil
}
return prepareDirectTerminalInputReader(in, wrapCancel)
}
func prepareDirectTerminalInputReader(in io.Reader, wrapCancel func(func()) func()) (io.Reader, func(), bool, error) {
if in == nil {
return nil, func() {}, false, nil
}
switch typed := in.(type) {
case *bufio.Reader:
return prepareBufferedTerminalInputReader(typed)
case *bufio.ReadWriter:
if typed.Reader == nil {
return in, func() {}, false, nil
}
return prepareBufferedTerminalInputReader(typed.Reader)
}
if canceler, ok := in.(TerminalInputCanceler); ok {
return in, wrapCancel(func() {
_ = canceler.TerminalInputCancel()
}), true, nil
}
if file, ok := in.(*os.File); ok {
dup, err := duplicateTerminalInputFile(file)
if err != nil {
return nil, nil, false, fmt.Errorf("duplicate terminal input: %w", err)
}
return dup, wrapCancel(func() {
_ = dup.Close()
}), true, nil
}
if closer, ok := in.(io.ReadCloser); ok {
return closer, wrapCancel(func() {
_ = closer.Close()
}), true, nil
}
return in, func() {}, false, nil
}
func prepareBufferedTerminalInputReader(reader *bufio.Reader) (io.Reader, func(), bool, error) {
if reader == nil {
return nil, func() {}, false, nil
}
bufferedPrefix, err := snapshotBufferedPrefix(reader)
if err != nil {
return nil, nil, false, err
}
underlying := unwrapBufioReader(reader)
if underlying == nil {
if len(bufferedPrefix) == 0 {
return reader, func() {}, false, nil
}
return io.MultiReader(bytes.NewReader(bufferedPrefix), reader), func() {}, false, nil
}
prepared, cancel, cancelable, err := prepareTerminalInputReader(underlying)
if err != nil {
return nil, nil, false, err
}
if len(bufferedPrefix) == 0 {
return prepared, cancel, cancelable, nil
}
if prepared == nil {
return bytes.NewReader(bufferedPrefix), cancel, cancelable, nil
}
return io.MultiReader(bytes.NewReader(bufferedPrefix), prepared), cancel, cancelable, nil
}
func snapshotBufferedPrefix(reader *bufio.Reader) ([]byte, error) {
if reader == nil {
return nil, nil
}
buffered := reader.Buffered()
if buffered == 0 {
return nil, nil
}
chunk, err := reader.Peek(buffered)
if err != nil && !errors.Is(err, io.EOF) {
return nil, fmt.Errorf("peek terminal input buffer: %w", err)
}
prefix := append([]byte(nil), chunk...)
if _, err := reader.Discard(len(prefix)); err != nil {
return nil, fmt.Errorf("discard terminal input buffer: %w", err)
}
return prefix, nil
}
func unwrapBufioReader(reader *bufio.Reader) io.Reader {
if reader == nil {
return nil
}
value := reflect.ValueOf(reader)
if value.Kind() != reflect.Pointer || value.IsNil() {
return nil
}
field := value.Elem().FieldByName("rd")
if !field.IsValid() {
return nil
}
underlyingValue := reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem()
underlying, ok := underlyingValue.Interface().(io.Reader)
if !ok || underlying == nil || sameReader(underlying, reader) {
return nil
}
return underlying
}
func sameReader(left io.Reader, right io.Reader) bool {
if left == nil || right == nil {
return false
}
leftValue := reflect.ValueOf(left)
rightValue := reflect.ValueOf(right)
if !leftValue.IsValid() || !rightValue.IsValid() {
return false
}
if leftValue.Kind() != reflect.Pointer || rightValue.Kind() != reflect.Pointer {
return false
}
return leftValue.Pointer() == rightValue.Pointer()
}

View File

@ -0,0 +1,49 @@
package starssh
import (
"bufio"
"io"
"testing"
"time"
)
func TestPrepareTerminalInputReaderAdapterCancelUnblocksRead(t *testing.T) {
reader, writer := io.Pipe()
defer reader.Close()
defer writer.Close()
adapter := TerminalInputAdapter{
Reader: bufio.NewReader(reader),
Source: reader,
Cancel: func() error {
return reader.CloseWithError(io.ErrClosedPipe)
},
}
prepared, cancel, cancelable, err := prepareTerminalInputReader(adapter)
if err != nil {
t.Fatalf("prepare adapter input: %v", err)
}
if !cancelable {
t.Fatal("expected adapter-backed input to be cancelable")
}
done := make(chan error, 1)
go func() {
buf := make([]byte, 1)
_, readErr := prepared.Read(buf)
done <- readErr
}()
time.Sleep(50 * time.Millisecond)
cancel()
select {
case readErr := <-done:
if readErr == nil {
t.Fatal("expected adapter cancel to interrupt blocking read")
}
case <-time.After(time.Second):
t.Fatal("blocking adapter input did not unblock after cancel")
}
}

290
terminal_input_test.go Normal file
View File

@ -0,0 +1,290 @@
package starssh
import (
"bufio"
"bytes"
"io"
"os"
"testing"
"time"
)
type terminalInputProvider struct {
io.Reader
source io.Reader
}
func (p terminalInputProvider) TerminalInputSource() io.Reader {
return p.source
}
type prefixedReadCloser struct {
io.Reader
io.Closer
}
func TestPrepareTerminalInputReaderBufioReaderPreservesBufferedBytes(t *testing.T) {
reader, writer, err := os.Pipe()
if err != nil {
t.Fatalf("create pipe: %v", err)
}
defer reader.Close()
if _, err := writer.Write([]byte("hello world")); err != nil {
writer.Close()
t.Fatalf("write pipe: %v", err)
}
if err := writer.Close(); err != nil {
t.Fatalf("close writer: %v", err)
}
buffered := bufio.NewReaderSize(reader, 4)
peeked, err := buffered.Peek(5)
if err != nil {
t.Fatalf("prime buffer: %v", err)
}
if string(peeked) != "hello" {
t.Fatalf("unexpected buffered prefix: %q", string(peeked))
}
prepared, cancel, cancelable, err := prepareTerminalInputReader(buffered)
if err != nil {
t.Fatalf("prepare input: %v", err)
}
defer cancel()
if !cancelable {
t.Fatal("expected cancelable reader")
}
data, err := io.ReadAll(prepared)
if err != nil {
t.Fatalf("read prepared input: %v", err)
}
if string(data) != "hello world" {
t.Fatalf("unexpected prepared input: %q", string(data))
}
}
func TestPrepareTerminalInputReaderBufioReadWriterPreservesBufferedBytes(t *testing.T) {
reader, writer, err := os.Pipe()
if err != nil {
t.Fatalf("create pipe: %v", err)
}
defer reader.Close()
if _, err := writer.Write([]byte("buffered payload")); err != nil {
writer.Close()
t.Fatalf("write pipe: %v", err)
}
if err := writer.Close(); err != nil {
t.Fatalf("close writer: %v", err)
}
readWriter := bufio.NewReadWriter(bufio.NewReaderSize(reader, 8), bufio.NewWriter(io.Discard))
if _, err := readWriter.Reader.Peek(8); err != nil {
t.Fatalf("prime readwriter buffer: %v", err)
}
prepared, cancel, cancelable, err := prepareTerminalInputReader(readWriter)
if err != nil {
t.Fatalf("prepare readwriter input: %v", err)
}
defer cancel()
if !cancelable {
t.Fatal("expected cancelable readwriter input")
}
data, err := io.ReadAll(prepared)
if err != nil {
t.Fatalf("read prepared readwriter input: %v", err)
}
if string(data) != "buffered payload" {
t.Fatalf("unexpected prepared readwriter input: %q", string(data))
}
}
func TestPrepareTerminalInputReaderBufioReaderFallbackKeepsData(t *testing.T) {
buffered := bufio.NewReader(bytes.NewBufferString("abc123"))
if _, err := buffered.Peek(3); err != nil {
t.Fatalf("prime buffer: %v", err)
}
prepared, cancel, cancelable, err := prepareTerminalInputReader(buffered)
if err != nil {
t.Fatalf("prepare fallback input: %v", err)
}
defer cancel()
if cancelable {
t.Fatal("expected non-cancelable reader")
}
data, err := io.ReadAll(prepared)
if err != nil {
t.Fatalf("read fallback input: %v", err)
}
if string(data) != "abc123" {
t.Fatalf("unexpected fallback input: %q", string(data))
}
}
func TestPrepareTerminalInputReaderProviderPrefersExplicitSource(t *testing.T) {
reader, writer, err := os.Pipe()
if err != nil {
t.Fatalf("create pipe: %v", err)
}
defer reader.Close()
if _, err := writer.Write([]byte("provider data")); err != nil {
writer.Close()
t.Fatalf("write pipe: %v", err)
}
if err := writer.Close(); err != nil {
t.Fatalf("close writer: %v", err)
}
buffered := bufio.NewReader(reader)
prefix, err := buffered.Peek(len("provider data"))
if err != nil {
t.Fatalf("prime buffer: %v", err)
}
source := prefixedReadCloser{
Reader: io.MultiReader(bytes.NewReader(append([]byte(nil), prefix...)), reader),
Closer: reader,
}
provider := terminalInputProvider{
Reader: buffered,
source: source,
}
prepared, cancel, cancelable, err := prepareTerminalInputReader(provider)
if err != nil {
t.Fatalf("prepare provider input: %v", err)
}
defer cancel()
if !cancelable {
t.Fatal("expected provider-backed input to be cancelable")
}
data, err := io.ReadAll(prepared)
if err != nil {
t.Fatalf("read provider input: %v", err)
}
if string(data) != "provider data" {
t.Fatalf("unexpected provider input: %q", string(data))
}
}
func TestPrepareTerminalInputReaderProviderCancelUnblocksRead(t *testing.T) {
reader, writer := io.Pipe()
defer reader.Close()
defer writer.Close()
buffered := bufio.NewReader(reader)
provider := terminalInputProvider{
Reader: buffered,
source: reader,
}
prepared, cancel, cancelable, err := prepareTerminalInputReader(provider)
if err != nil {
t.Fatalf("prepare input: %v", err)
}
if !cancelable {
t.Fatal("expected provider-backed input to be cancelable")
}
done := make(chan error, 1)
go func() {
buf := make([]byte, 1)
_, readErr := prepared.Read(buf)
done <- readErr
}()
time.Sleep(50 * time.Millisecond)
cancel()
select {
case readErr := <-done:
if readErr == nil {
t.Fatal("expected cancel to interrupt blocking read")
}
case <-time.After(time.Second):
t.Fatal("blocking read did not unblock after cancel")
}
}
func TestPrepareTerminalInputReaderBufioReaderCancelUnblocksRead(t *testing.T) {
reader, writer := io.Pipe()
defer reader.Close()
defer writer.Close()
buffered := bufio.NewReader(reader)
prepared, cancel, cancelable, err := prepareTerminalInputReader(buffered)
if err != nil {
t.Fatalf("prepare input: %v", err)
}
if !cancelable {
t.Fatal("expected bufio reader to be cancelable")
}
done := make(chan error, 1)
go func() {
buf := make([]byte, 1)
_, readErr := prepared.Read(buf)
done <- readErr
}()
time.Sleep(50 * time.Millisecond)
cancel()
select {
case readErr := <-done:
if readErr == nil {
t.Fatal("expected cancel to interrupt blocking read")
}
case <-time.After(time.Second):
t.Fatal("blocking bufio reader did not unblock after cancel")
}
}
func TestPrepareTerminalInputReaderBufioReadWriterCancelUnblocksRead(t *testing.T) {
reader, writer := io.Pipe()
defer reader.Close()
defer writer.Close()
readWriter := bufio.NewReadWriter(bufio.NewReader(reader), bufio.NewWriter(io.Discard))
prepared, cancel, cancelable, err := prepareTerminalInputReader(readWriter)
if err != nil {
t.Fatalf("prepare readwriter input: %v", err)
}
if !cancelable {
t.Fatal("expected bufio readwriter to be cancelable")
}
done := make(chan error, 1)
go func() {
buf := make([]byte, 1)
_, readErr := prepared.Read(buf)
done <- readErr
}()
time.Sleep(50 * time.Millisecond)
cancel()
select {
case readErr := <-done:
if readErr == nil {
t.Fatal("expected cancel to interrupt blocking read")
}
case <-time.After(time.Second):
t.Fatal("blocking readwriter input did not unblock after cancel")
}
}

21
terminal_input_unix.go Normal file
View File

@ -0,0 +1,21 @@
//go:build !windows
package starssh
import (
"os"
"syscall"
)
func duplicateTerminalInputFile(file *os.File) (*os.File, error) {
if file == nil {
return nil, os.ErrInvalid
}
fd, err := syscall.Dup(int(file.Fd()))
if err != nil {
return nil, err
}
syscall.CloseOnExec(fd)
return os.NewFile(uintptr(fd), file.Name()), nil
}

32
terminal_input_windows.go Normal file
View File

@ -0,0 +1,32 @@
//go:build windows
package starssh
import (
"os"
"golang.org/x/sys/windows"
)
func duplicateTerminalInputFile(file *os.File) (*os.File, error) {
if file == nil {
return nil, os.ErrInvalid
}
currentProcess := windows.CurrentProcess()
var duplicated windows.Handle
err := windows.DuplicateHandle(
currentProcess,
windows.Handle(file.Fd()),
currentProcess,
&duplicated,
0,
false,
windows.DUPLICATE_SAME_ACCESS,
)
if err != nil {
return nil, err
}
return os.NewFile(uintptr(duplicated), file.Name()), nil
}

336
transport.go Normal file
View File

@ -0,0 +1,336 @@
package starssh
import (
"bufio"
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"net/http"
"strconv"
"strings"
"time"
)
type bufferedConn struct {
net.Conn
reader *bufio.Reader
}
func (c *bufferedConn) Read(p []byte) (int, error) {
if c == nil || c.reader == nil {
return 0, io.EOF
}
return c.reader.Read(p)
}
func resolveDialContext(info LoginInput) DialContextFunc {
if info.DialContext != nil {
return info.DialContext
}
dialer := &net.Dialer{
Timeout: info.Timeout,
}
return dialer.DialContext
}
func dialTargetConn(ctx context.Context, info LoginInput) (net.Conn, *StarSSH, error) {
targetAddr := joinHostPort(info.Addr, info.Port)
if info.Jump != nil {
return dialViaJump(ctx, info, targetAddr)
}
dialContext := resolveDialContext(info)
proxyConfig := normalizeProxyConfig(info.Proxy, info.Timeout)
if proxyConfig != nil {
return dialViaProxy(ctx, dialContext, *proxyConfig, targetAddr)
}
conn, err := dialContext(ctx, "tcp", targetAddr)
return conn, nil, err
}
func dialViaJump(ctx context.Context, info LoginInput, targetAddr string) (net.Conn, *StarSSH, error) {
if info.Jump == nil {
return nil, nil, errors.New("jump login info is nil")
}
jumpClient, err := loginWithContext(ctx, *info.Jump)
if err != nil {
return nil, nil, err
}
conn, err := jumpClient.dialTCPContext(ctx, "tcp", targetAddr, jumpClient.Close)
if err != nil {
_ = jumpClient.Close()
return nil, nil, err
}
return conn, jumpClient, nil
}
func dialViaProxy(ctx context.Context, dialContext DialContextFunc, proxy ProxyConfig, targetAddr string) (net.Conn, *StarSSH, error) {
if dialContext == nil {
return nil, nil, errors.New("dial context is nil")
}
if strings.TrimSpace(proxy.Addr) == "" {
return nil, nil, errors.New("proxy address is empty")
}
switch proxy.Type {
case ProxyTypeSOCKS5:
conn, err := dialSOCKS5(ctx, dialContext, proxy, targetAddr)
return conn, nil, err
case ProxyTypeHTTPConnect:
conn, err := dialHTTPConnect(ctx, dialContext, proxy, targetAddr)
return conn, nil, err
default:
return nil, nil, fmt.Errorf("unsupported proxy type %q", proxy.Type)
}
}
func dialHTTPConnect(ctx context.Context, dialContext DialContextFunc, proxy ProxyConfig, targetAddr string) (net.Conn, error) {
conn, err := dialContext(ctx, "tcp", proxy.Addr)
if err != nil {
return nil, err
}
restoreDeadline := applyConnDeadline(conn, ctx, proxy.Timeout)
defer restoreDeadline()
request := fmt.Sprintf("CONNECT %s HTTP/1.1\r\nHost: %s\r\n", targetAddr, targetAddr)
if proxy.Username != "" || proxy.Password != "" {
token := base64.StdEncoding.EncodeToString([]byte(proxy.Username + ":" + proxy.Password))
request += "Proxy-Authorization: Basic " + token + "\r\n"
}
request += "\r\n"
if _, err := io.WriteString(conn, request); err != nil {
_ = conn.Close()
return nil, err
}
reader := bufio.NewReader(conn)
response, err := http.ReadResponse(reader, &http.Request{Method: http.MethodConnect})
if err != nil {
_ = conn.Close()
return nil, err
}
defer response.Body.Close()
if response.StatusCode < 200 || response.StatusCode >= 300 {
_, _ = io.Copy(io.Discard, io.LimitReader(response.Body, 1024))
_ = conn.Close()
return nil, fmt.Errorf("http CONNECT proxy rejected target %s: %s", targetAddr, response.Status)
}
if reader.Buffered() == 0 {
return conn, nil
}
return &bufferedConn{Conn: conn, reader: reader}, nil
}
func dialSOCKS5(ctx context.Context, dialContext DialContextFunc, proxy ProxyConfig, targetAddr string) (net.Conn, error) {
conn, err := dialContext(ctx, "tcp", proxy.Addr)
if err != nil {
return nil, err
}
restoreDeadline := applyConnDeadline(conn, ctx, proxy.Timeout)
defer restoreDeadline()
methods := []byte{0x00}
useAuth := proxy.Username != "" || proxy.Password != ""
if useAuth {
methods = append(methods, 0x02)
}
hello := append([]byte{0x05, byte(len(methods))}, methods...)
if _, err := conn.Write(hello); err != nil {
_ = conn.Close()
return nil, err
}
response := make([]byte, 2)
if _, err := io.ReadFull(conn, response); err != nil {
_ = conn.Close()
return nil, err
}
if response[0] != 0x05 {
_ = conn.Close()
return nil, fmt.Errorf("invalid socks5 version %d", response[0])
}
if response[1] == 0xFF {
_ = conn.Close()
return nil, errors.New("socks5 proxy has no acceptable auth method")
}
if response[1] == 0x02 {
if err := writeSOCKS5UserPassAuth(conn, proxy.Username, proxy.Password); err != nil {
_ = conn.Close()
return nil, err
}
}
if err := writeSOCKS5Connect(conn, targetAddr); err != nil {
_ = conn.Close()
return nil, err
}
if err := readSOCKS5ConnectResponse(conn); err != nil {
_ = conn.Close()
return nil, err
}
return conn, nil
}
func writeSOCKS5UserPassAuth(conn net.Conn, username string, password string) error {
if len(username) > 255 || len(password) > 255 {
return errors.New("socks5 username/password too long")
}
request := make([]byte, 0, 3+len(username)+len(password))
request = append(request, 0x01, byte(len(username)))
request = append(request, []byte(username)...)
request = append(request, byte(len(password)))
request = append(request, []byte(password)...)
if _, err := conn.Write(request); err != nil {
return err
}
response := make([]byte, 2)
if _, err := io.ReadFull(conn, response); err != nil {
return err
}
if response[1] != 0x00 {
return errors.New("socks5 username/password authentication failed")
}
return nil
}
func writeSOCKS5Connect(conn net.Conn, targetAddr string) error {
host, portString, err := net.SplitHostPort(targetAddr)
if err != nil {
return err
}
port, err := strconv.Atoi(portString)
if err != nil {
return err
}
if port < 0 || port > 65535 {
return fmt.Errorf("invalid port %d", port)
}
request := []byte{0x05, 0x01, 0x00}
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
request = append(request, 0x01)
request = append(request, ip4...)
} else {
request = append(request, 0x04)
request = append(request, ip.To16()...)
}
} else {
if len(host) > 255 {
return errors.New("socks5 target host too long")
}
request = append(request, 0x03, byte(len(host)))
request = append(request, []byte(host)...)
}
request = append(request, byte(port>>8), byte(port))
_, err = conn.Write(request)
return err
}
func readSOCKS5ConnectResponse(conn net.Conn) error {
header := make([]byte, 4)
if _, err := io.ReadFull(conn, header); err != nil {
return err
}
if header[0] != 0x05 {
return fmt.Errorf("invalid socks5 response version %d", header[0])
}
if header[1] != 0x00 {
return fmt.Errorf("socks5 connect failed with code %d", header[1])
}
switch header[3] {
case 0x01:
if _, err := io.ReadFull(conn, make([]byte, 4)); err != nil {
return err
}
case 0x03:
size := make([]byte, 1)
if _, err := io.ReadFull(conn, size); err != nil {
return err
}
if _, err := io.ReadFull(conn, make([]byte, int(size[0]))); err != nil {
return err
}
case 0x04:
if _, err := io.ReadFull(conn, make([]byte, 16)); err != nil {
return err
}
default:
return fmt.Errorf("unsupported socks5 bind address type %d", header[3])
}
_, err := io.ReadFull(conn, make([]byte, 2))
return err
}
func applyConnDeadline(conn net.Conn, ctx context.Context, timeout time.Duration) func() {
if conn == nil {
return func() {}
}
var (
deadline time.Time
hasValue bool
)
if ctx != nil {
if ctxDeadline, ok := ctx.Deadline(); ok {
deadline = ctxDeadline
hasValue = true
}
}
if timeout > 0 {
timeoutDeadline := time.Now().Add(timeout)
if !hasValue || timeoutDeadline.Before(deadline) {
deadline = timeoutDeadline
hasValue = true
}
}
if !hasValue {
return func() {}
}
_ = conn.SetDeadline(deadline)
return func() {
_ = conn.SetDeadline(time.Time{})
}
}
func normalizeProxyConfig(proxy *ProxyConfig, defaultTimeout time.Duration) *ProxyConfig {
if proxy == nil {
return nil
}
normalized := *proxy
normalized.Type = ProxyType(strings.ToLower(strings.TrimSpace(string(normalized.Type))))
normalized.Addr = strings.TrimSpace(normalized.Addr)
if normalized.Timeout <= 0 {
normalized.Timeout = defaultTimeout
}
return &normalized
}
func joinHostPort(host string, port int) string {
return net.JoinHostPort(strings.TrimSpace(host), strconv.Itoa(port))
}

196
types.go Normal file
View File

@ -0,0 +1,196 @@
package starssh
import (
"bufio"
"context"
"io"
"net"
"sync"
"time"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
)
const (
defaultSSHPort = 22
defaultLoginTimeout = 5 * time.Second
defaultKeepAliveTimeout = 3 * time.Second
defaultShellPollInterval = 120 * time.Millisecond
defaultShellSetupDelay = 200 * time.Millisecond
defaultShellSetupTimeout = 3 * time.Second
defaultShellWaitTimeout = 30 * time.Second
defaultShellPromptToken = "__STARSSH_PROMPT__>"
defaultPTYTerm = "xterm"
defaultPTYRows = 500
defaultPTYColumns = 250
defaultTransferBufferSize = 1024 * 1024
defaultExecStreamMaxPendingChunks = 256
defaultExecStreamMaxPendingBytes = 4 * 1024 * 1024
)
type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error)
type ProxyType string
const (
ProxyTypeSOCKS5 ProxyType = "socks5"
ProxyTypeHTTPConnect ProxyType = "http_connect"
)
type ProxyConfig struct {
Type ProxyType
Addr string
Username string
Password string
Timeout time.Duration
}
type AuthMethodKind string
const (
AuthMethodPrivateKey AuthMethodKind = "private_key"
AuthMethodPassword AuthMethodKind = "password"
AuthMethodKeyboardInteractive AuthMethodKind = "keyboard_interactive"
AuthMethodSSHAgent AuthMethodKind = "ssh_agent"
)
type StarSSH struct {
stateMu sync.RWMutex
Client *ssh.Client
PublicKey ssh.PublicKey
PubkeyBase64 string
Hostname string
RemoteAddr net.Addr
Banner string
LoginInfo LoginInput
online bool
upstream *StarSSH
sftpClient *sftp.Client
sftpMu sync.Mutex
keepaliveMu sync.Mutex
keepaliveStop chan struct{}
keepaliveDone chan struct{}
}
type LoginInput struct {
KeyExchanges []string
Ciphers []string
MACs []string
User string
Password string
PasswordCallback func() (string, error)
KeyboardInteractiveCallback ssh.KeyboardInteractiveChallenge
Prikey string
PrikeyPwd string
DisableSSHAgent bool
AuthOrder []AuthMethodKind
Addr string
Port int
Timeout time.Duration
DialContext DialContextFunc
Proxy *ProxyConfig
Jump *LoginInput
KeepAliveInterval time.Duration
KeepAliveTimeout time.Duration
HostKeyCallback func(string, net.Addr, ssh.PublicKey) error
BannerCallback func(string) error
}
// StarShell keeps the legacy prompt-driven helper for POSIX-style scripted shell interactions.
// It is not a generic cross-shell abstraction; for product-grade interactive terminals, prefer TerminalSession.
type StarShell struct {
Keyword string
UseWaitDefault bool
WaitTimeout time.Duration
Session *ssh.Session
in io.Writer
out *bufio.Reader
er *bufio.Reader
outbyte []byte
errbyte []byte
lastout int64
errors error
isprint bool
isfuncs bool
iscolor bool
isecho bool
rw sync.RWMutex
funcs func(string)
writeMu sync.Mutex
commandMu sync.Mutex
closeOnce sync.Once
promptToken string
}
type TerminalConfig struct {
Term string
Rows int
Columns int
Modes ssh.TerminalModes
}
type TerminalControl byte
const (
TerminalControlInterrupt TerminalControl = 0x03
TerminalControlEOF TerminalControl = 0x04
TerminalControlBell TerminalControl = 0x07
TerminalControlBackspace TerminalControl = 0x08
TerminalControlLineKill TerminalControl = 0x15
TerminalControlQuit TerminalControl = 0x1c
TerminalControlSuspend TerminalControl = 0x1a
TerminalControlPauseOutput TerminalControl = 0x13
TerminalControlResumeOutput TerminalControl = 0x11
)
type TerminalCloseReason string
const (
TerminalCloseReasonUnknown TerminalCloseReason = ""
TerminalCloseReasonExit TerminalCloseReason = "exit"
TerminalCloseReasonSignal TerminalCloseReason = "signal"
TerminalCloseReasonClosed TerminalCloseReason = "closed"
TerminalCloseReasonContextCanceled TerminalCloseReason = "context_canceled"
TerminalCloseReasonDeadlineExceeded TerminalCloseReason = "deadline_exceeded"
TerminalCloseReasonTransportError TerminalCloseReason = "transport_error"
)
type TerminalExitInfo struct {
ExitCode int
ExitSignal string
ExitMessage string
Reason TerminalCloseReason
}
type TerminalSession struct {
Session *ssh.Session
ID string
Label string
Metadata map[string]string
stdin io.WriteCloser
stdout io.Reader
stderr io.Reader
attachMu sync.RWMutex
in io.Reader
out io.Writer
errOut io.Writer
runOnce sync.Once
runDone chan struct{}
runErr error
waitOnce sync.Once
waitDone chan struct{}
stateMu sync.RWMutex
waitErr error
exitInfo TerminalExitInfo
closeReason TerminalCloseReason
closeErr error
closeOnce sync.Once
}

162
utils.go Normal file
View File

@ -0,0 +1,162 @@
package starssh
import (
"crypto/rand"
"encoding/hex"
"fmt"
"regexp"
"strconv"
"strings"
"time"
)
var (
ansiCSIRegexp = regexp.MustCompile(`\x1b\[[0-9;?]*[ -/]*[@-~]`)
ansiOSCRegexp = regexp.MustCompile(`\x1b\][^\x07]*(\x07|\x1b\\)`)
leadingIntRegexp = regexp.MustCompile(`^[+-]?\d+`)
)
func SedColor(str string) string {
return stripControlSequences(str)
}
func normalizeShellOutput(raw string) string {
return strings.TrimSpace(strings.ReplaceAll(raw, "\r\n", "\n"))
}
func stripControlSequences(raw string) string {
cleaned := ansiOSCRegexp.ReplaceAllString(raw, "")
cleaned = ansiCSIRegexp.ReplaceAllString(cleaned, "")
cleaned = strings.ReplaceAll(cleaned, "\r", "")
cleaned = strings.Map(func(r rune) rune {
if r == '\n' || r == '\t' {
return r
}
if r < 0x20 || r == 0x7f {
return -1
}
return r
}, cleaned)
return cleaned
}
func stripLeadingEcho(output string, command string, markerCommand string) string {
result := output
if command == "" {
return result
}
withMarker := command
if markerCommand != "" {
withMarker += "\n" + markerCommand
}
if strings.HasPrefix(result, withMarker) {
return strings.TrimSpace(strings.TrimPrefix(result, withMarker))
}
if strings.HasPrefix(result, command) {
return strings.TrimSpace(strings.TrimPrefix(result, command))
}
return result
}
func collectLinesWithoutTokens(output string, tokens ...string) string {
lines := strings.Split(output, "\n")
filtered := make([]string, 0, len(lines))
for _, line := range lines {
skip := false
for _, token := range tokens {
if token != "" && strings.Contains(line, token) {
skip = true
break
}
}
if !skip {
filtered = append(filtered, line)
}
}
return strings.TrimSpace(strings.Join(filtered, "\n"))
}
func collectLinesForCommandOutput(output string, promptToken string, tokens ...string) string {
lines := strings.Split(output, "\n")
filtered := make([]string, 0, len(lines))
for _, line := range lines {
skip := false
for _, token := range tokens {
if token != "" && strings.Contains(line, token) {
skip = true
break
}
}
if !skip && promptToken != "" && strings.Contains(line, promptToken) {
skip = true
}
if !skip {
filtered = append(filtered, line)
}
}
return strings.TrimSpace(strings.Join(filtered, "\n"))
}
func shellSingleQuote(s string) string {
return "'" + strings.ReplaceAll(s, "'", "'\"'\"'") + "'"
}
func normalizeBufferSize(bufcap int) int {
if bufcap <= 0 {
return defaultTransferBufferSize
}
return bufcap
}
func newCommandTokens() (beginToken string, endToken string) {
nonce := newNonce(8)
return "__STARSSH_BEGIN_" + nonce + "__", "__STARSSH_END_" + nonce + "__"
}
func newNonce(size int) string {
if size <= 0 {
size = 8
}
buf := make([]byte, size)
if _, err := rand.Read(buf); err != nil {
return fmt.Sprintf("%d", time.Now().UnixNano())
}
return strings.ToUpper(hex.EncodeToString(buf))
}
func splitByEndToken(output string, endToken string) (before string, exitCode int, found bool, parseErr error) {
prefix := endToken + ":"
lines := strings.Split(output, "\n")
beforeLines := make([]string, 0, len(lines))
for _, line := range lines {
trimmedLine := strings.TrimSpace(line)
if !strings.HasPrefix(trimmedLine, prefix) {
beforeLines = append(beforeLines, line)
continue
}
codeText := strings.TrimSpace(trimmedLine[len(prefix):])
match := leadingIntRegexp.FindString(codeText)
if match == "" || match != codeText {
beforeLines = append(beforeLines, line)
continue
}
code, err := strconv.Atoi(match)
if err != nil {
beforeLines = append(beforeLines, line)
continue
}
return strings.Join(beforeLines, "\n"), code, true, nil
}
return strings.Join(beforeLines, "\n"), 0, false, nil
}