refactor: 重构 starssh 核心运行时并补强 ssh/exec/terminal/sftp 能力
- 拆分原有单体 ssh.go,按职责重组为 types、utils、transport、login、keepalive、session、exec、pool、shell、terminal、forward、hostkey、state 等模块,并补充平台相关实现 - 重做登录与连接运行时,补齐基于 context 的建连、jump/proxy 链路、可配置认证顺序,以及 Unix/Windows 下的 ssh-agent 支持 - 新增正式非交互执行模型 ExecRequest/ExecResult,支持流式输出、溢出统计、超时控制,以及 posix/powershell/cmd/raw 多方言执行 - 保留旧 shell 风格兼容接口,同时让路径/用户探测等 helper 具备跨 shell fallback,避免 Windows 目标继续硬依赖 POSIX 命令 - 新增 TerminalSession 作为原始交互终端基座,提供 IO attach、resize、signal/control、退出状态与关闭原因管理 - 重构端口转发语义,默认复用当前 SSH 连接,并显式提供 detached 的本地/动态转发模式承载隔离场景 - 梳理 keepalive 与取消语义,区分仅取消本次操作和关闭整条连接,并统一连接状态与传输关闭路径 - 围绕新的 session/连接生命周期重做执行池与运行时支撑 - 大幅增强 SFTP 传输链路,补齐更安全的原子替换、校验、进度回调、重试隔离、可复用 client 生命周期与失败语义 - 新增取消语义、keepalive、SFTP、forward、terminal input 等关键回归测试,提升核心链路稳定性
This commit is contained in:
parent
d6fbea8468
commit
f20eb653ae
9
.gitignore
vendored
Normal file
9
.gitignore
vendored
Normal file
@ -0,0 +1,9 @@
|
||||
.sentrux/
|
||||
agent_readme.md
|
||||
target.md
|
||||
.gocache/
|
||||
.tmp_*/
|
||||
.codex/
|
||||
.idea/
|
||||
agents.md
|
||||
.codex
|
||||
201
LICENSE
Normal file
201
LICENSE
Normal file
@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
172
cancel_semantics_test.go
Normal file
172
cancel_semantics_test.go
Normal file
@ -0,0 +1,172 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestPingContextDoesNotCloseConnectionOnCancel(t *testing.T) {
|
||||
oldSendKeepAliveRequest := sendKeepAliveRequest
|
||||
oldCloseSSHClient := closeSSHClient
|
||||
t.Cleanup(func() {
|
||||
sendKeepAliveRequest = oldSendKeepAliveRequest
|
||||
closeSSHClient = oldCloseSSHClient
|
||||
})
|
||||
|
||||
sendKeepAliveRequest = func(ctx context.Context, client sshClientRequester) error {
|
||||
<-ctx.Done()
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var closeCalls atomic.Int32
|
||||
closeSSHClient = func(client sshClientRequester) error {
|
||||
closeCalls.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
client := &ssh.Client{}
|
||||
star := &StarSSH{}
|
||||
star.setTransport(client, nil)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := star.PingContext(ctx)
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Fatalf("expected deadline exceeded, got %v", err)
|
||||
}
|
||||
if closeCalls.Load() != 0 {
|
||||
t.Fatalf("expected PingContext to keep connection open, close calls=%d", closeCalls.Load())
|
||||
}
|
||||
if got := star.snapshotSSHClient(); got != client {
|
||||
t.Fatal("expected ssh client to remain attached after PingContext cancel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPingContextCloseOnCancelClosesConnection(t *testing.T) {
|
||||
oldSendKeepAliveRequest := sendKeepAliveRequest
|
||||
oldCloseSSHClient := closeSSHClient
|
||||
t.Cleanup(func() {
|
||||
sendKeepAliveRequest = oldSendKeepAliveRequest
|
||||
closeSSHClient = oldCloseSSHClient
|
||||
})
|
||||
|
||||
sendKeepAliveRequest = func(ctx context.Context, client sshClientRequester) error {
|
||||
<-ctx.Done()
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var closeCalls atomic.Int32
|
||||
closeSSHClient = func(client sshClientRequester) error {
|
||||
closeCalls.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
star := &StarSSH{}
|
||||
star.setTransport(&ssh.Client{}, nil)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := star.PingContextCloseOnCancel(ctx)
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Fatalf("expected deadline exceeded, got %v", err)
|
||||
}
|
||||
if closeCalls.Load() != 1 {
|
||||
t.Fatalf("expected exactly one close call, got %d", closeCalls.Load())
|
||||
}
|
||||
if got := star.snapshotSSHClient(); got != nil {
|
||||
t.Fatal("expected ssh client to be detached after PingContextCloseOnCancel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialTCPContextDoesNotCloseConnectionOnCancel(t *testing.T) {
|
||||
oldDialSSHClient := dialSSHClient
|
||||
oldCloseSSHClient := closeSSHClient
|
||||
t.Cleanup(func() {
|
||||
dialSSHClient = oldDialSSHClient
|
||||
closeSSHClient = oldCloseSSHClient
|
||||
})
|
||||
|
||||
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
|
||||
<-ctx.Done()
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
var closeCalls atomic.Int32
|
||||
closeSSHClient = func(client sshClientRequester) error {
|
||||
closeCalls.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
client := &ssh.Client{}
|
||||
star := &StarSSH{}
|
||||
star.setTransport(client, nil)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
conn, err := star.DialTCPContext(ctx, "tcp", "127.0.0.1:22")
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil connection on canceled dial")
|
||||
}
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Fatalf("expected deadline exceeded, got %v", err)
|
||||
}
|
||||
if closeCalls.Load() != 0 {
|
||||
t.Fatalf("expected DialTCPContext to keep connection open, close calls=%d", closeCalls.Load())
|
||||
}
|
||||
if got := star.snapshotSSHClient(); got != client {
|
||||
t.Fatal("expected ssh client to remain attached after DialTCPContext cancel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialTCPContextCloseOnCancelClosesConnection(t *testing.T) {
|
||||
oldDialSSHClient := dialSSHClient
|
||||
oldCloseSSHClient := closeSSHClient
|
||||
t.Cleanup(func() {
|
||||
dialSSHClient = oldDialSSHClient
|
||||
closeSSHClient = oldCloseSSHClient
|
||||
})
|
||||
|
||||
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
|
||||
<-ctx.Done()
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
var closeCalls atomic.Int32
|
||||
closeSSHClient = func(client sshClientRequester) error {
|
||||
closeCalls.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
star := &StarSSH{}
|
||||
star.setTransport(&ssh.Client{}, nil)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
conn, err := star.DialTCPContextCloseOnCancel(ctx, "tcp", "127.0.0.1:22")
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil connection on canceled dial")
|
||||
}
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Fatalf("expected deadline exceeded, got %v", err)
|
||||
}
|
||||
if closeCalls.Load() != 1 {
|
||||
t.Fatalf("expected exactly one close call, got %d", closeCalls.Load())
|
||||
}
|
||||
if got := star.snapshotSSHClient(); got != nil {
|
||||
t.Fatal("expected ssh client to be detached after DialTCPContextCloseOnCancel")
|
||||
}
|
||||
}
|
||||
69
exec_legacy_test.go
Normal file
69
exec_legacy_test.go
Normal file
@ -0,0 +1,69 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExistsFallsBackFromPOSIXToPowerShell(t *testing.T) {
|
||||
oldExecRequestRunner := execRequestRunner
|
||||
t.Cleanup(func() {
|
||||
execRequestRunner = oldExecRequestRunner
|
||||
})
|
||||
|
||||
var calls []ExecRequest
|
||||
execRequestRunner = func(s *StarSSH, ctx context.Context, req ExecRequest) (*ExecResult, error) {
|
||||
calls = append(calls, req)
|
||||
switch len(calls) {
|
||||
case 1:
|
||||
return &ExecResult{ExitCode: 127, Stderr: []byte("sh not found")}, nil
|
||||
case 2:
|
||||
return &ExecResult{Stdout: []byte("1\n")}, nil
|
||||
default:
|
||||
t.Fatalf("unexpected extra probe request: %+v", req)
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
star := &StarSSH{}
|
||||
if !star.Exists(`C:\Windows\System32`) {
|
||||
t.Fatal("expected helper to succeed after PowerShell fallback")
|
||||
}
|
||||
if len(calls) != 2 {
|
||||
t.Fatalf("expected two probe attempts, got %d", len(calls))
|
||||
}
|
||||
if calls[0].ShellDialect != ExecShellDialectRaw || !strings.HasPrefix(calls[0].Command, "sh -lc ") {
|
||||
t.Fatalf("unexpected first probe request: %+v", calls[0])
|
||||
}
|
||||
if calls[1].ShellDialect != ExecShellDialectRaw || !strings.HasPrefix(strings.ToLower(calls[1].Command), "powershell.exe ") {
|
||||
t.Fatalf("unexpected second probe request: %+v", calls[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserFallsBackToCMDWhenPowerShellVariantsFail(t *testing.T) {
|
||||
oldExecRequestRunner := execRequestRunner
|
||||
t.Cleanup(func() {
|
||||
execRequestRunner = oldExecRequestRunner
|
||||
})
|
||||
|
||||
var calls []ExecRequest
|
||||
execRequestRunner = func(s *StarSSH, ctx context.Context, req ExecRequest) (*ExecResult, error) {
|
||||
calls = append(calls, req)
|
||||
if len(calls) < 5 {
|
||||
return &ExecResult{ExitCode: 127, Stderr: []byte("command not found")}, nil
|
||||
}
|
||||
return &ExecResult{Stdout: []byte("HOST\\tester\r\n")}, nil
|
||||
}
|
||||
|
||||
star := &StarSSH{}
|
||||
if got := star.GetUser(); got != `HOST\tester` {
|
||||
t.Fatalf("unexpected user after fallback: %q", got)
|
||||
}
|
||||
if len(calls) != 5 {
|
||||
t.Fatalf("expected five probe attempts, got %d", len(calls))
|
||||
}
|
||||
if got := strings.ToLower(calls[4].Command); got != "cmd.exe /q /d /c whoami" {
|
||||
t.Fatalf("unexpected final fallback command: %q", calls[4].Command)
|
||||
}
|
||||
}
|
||||
694
forward.go
Normal file
694
forward.go
Normal file
@ -0,0 +1,694 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type ForwardRequest struct {
|
||||
ListenAddr string
|
||||
TargetAddr string
|
||||
DialContext DialContextFunc
|
||||
}
|
||||
|
||||
type DynamicForwardRequest struct {
|
||||
ListenAddr string
|
||||
}
|
||||
|
||||
type PortForwarder struct {
|
||||
listener net.Listener
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
acceptDone chan struct{}
|
||||
|
||||
connWG sync.WaitGroup
|
||||
closeOnce sync.Once
|
||||
|
||||
connMu sync.Mutex
|
||||
conns map[net.Conn]struct{}
|
||||
|
||||
errMu sync.Mutex
|
||||
err error
|
||||
|
||||
cleanupOnce sync.Once
|
||||
cleanupFns []func() error
|
||||
}
|
||||
|
||||
var dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
|
||||
return client.Dial(network, address)
|
||||
}
|
||||
|
||||
var newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return LoginContext(ctx, input)
|
||||
}
|
||||
|
||||
func (s *StarSSH) DialTCP(network string, address string) (net.Conn, error) {
|
||||
return s.DialTCPContext(context.Background(), network, address)
|
||||
}
|
||||
|
||||
func (s *StarSSH) DialTCPContext(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||
return s.dialTCPContext(ctx, network, address, nil)
|
||||
}
|
||||
|
||||
func (s *StarSSH) DialTCPContextCloseOnCancel(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||
return s.dialTCPContext(ctx, network, address, s.Close)
|
||||
}
|
||||
|
||||
func (s *StarSSH) dialTCPContext(ctx context.Context, network string, address string, onCancel func() error) (net.Conn, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if strings.TrimSpace(network) == "" {
|
||||
network = "tcp"
|
||||
}
|
||||
if strings.TrimSpace(address) == "" {
|
||||
return nil, errors.New("forward address is empty")
|
||||
}
|
||||
|
||||
type dialResult struct {
|
||||
conn net.Conn
|
||||
err error
|
||||
}
|
||||
|
||||
client, err := s.requireSSHClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
runCancel := func() {}
|
||||
if onCancel != nil {
|
||||
var cancelOnce sync.Once
|
||||
runCancel = func() {
|
||||
cancelOnce.Do(func() {
|
||||
_ = onCancel()
|
||||
})
|
||||
}
|
||||
|
||||
cancelDone := make(chan struct{})
|
||||
defer close(cancelDone)
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
runCancel()
|
||||
case <-cancelDone:
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
dialFunc := dialSSHClient
|
||||
resultCh := make(chan dialResult, 1)
|
||||
go func() {
|
||||
conn, err := dialFunc(ctx, client, network, address)
|
||||
if ctx.Err() != nil && conn != nil {
|
||||
_ = conn.Close()
|
||||
conn = nil
|
||||
}
|
||||
|
||||
select {
|
||||
case resultCh <- dialResult{conn: conn, err: err}:
|
||||
default:
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
return result.conn, result.err
|
||||
case <-ctx.Done():
|
||||
runCancel()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartLocalForward(req ForwardRequest) (*PortForwarder, error) {
|
||||
if _, err := s.requireSSHClient(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(req.ListenAddr) == "" {
|
||||
return nil, errors.New("local forward listen address is empty")
|
||||
}
|
||||
if strings.TrimSpace(req.TargetAddr) == "" {
|
||||
return nil, errors.New("local forward target address is empty")
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", req.ListenAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
forwarder := newPortForwarder(listener)
|
||||
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
|
||||
return s.DialTCPContext(ctx, "tcp", req.TargetAddr)
|
||||
})
|
||||
return forwarder, nil
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartLocalForwardDetached(req ForwardRequest) (*PortForwarder, error) {
|
||||
if _, err := s.requireSSHClient(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(req.ListenAddr) == "" {
|
||||
return nil, errors.New("local forward listen address is empty")
|
||||
}
|
||||
if strings.TrimSpace(req.TargetAddr) == "" {
|
||||
return nil, errors.New("local forward target address is empty")
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", req.ListenAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
forwardClient, err := s.newForwardDialClient(context.Background())
|
||||
if err != nil {
|
||||
_ = listener.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
forwarder := newPortForwarder(listener)
|
||||
forwarder.addCleanup(func() error {
|
||||
return normalizeAlreadyClosedError(forwardClient.Close())
|
||||
})
|
||||
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
|
||||
return forwardClient.DialTCPContext(ctx, "tcp", req.TargetAddr)
|
||||
})
|
||||
return forwarder, nil
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartRemoteForward(req ForwardRequest) (*PortForwarder, error) {
|
||||
client, err := s.requireSSHClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(req.ListenAddr) == "" {
|
||||
return nil, errors.New("remote forward listen address is empty")
|
||||
}
|
||||
if strings.TrimSpace(req.TargetAddr) == "" {
|
||||
return nil, errors.New("remote forward target address is empty")
|
||||
}
|
||||
|
||||
listener, err := client.Listen("tcp", req.ListenAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dialContext := req.DialContext
|
||||
if dialContext == nil {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: defaultLoginTimeout,
|
||||
}
|
||||
dialContext = dialer.DialContext
|
||||
}
|
||||
|
||||
forwarder := newPortForwarder(listener)
|
||||
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
|
||||
return dialContext(ctx, "tcp", req.TargetAddr)
|
||||
})
|
||||
return forwarder, nil
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartDynamicForward(req DynamicForwardRequest) (*PortForwarder, error) {
|
||||
if _, err := s.requireSSHClient(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(req.ListenAddr) == "" {
|
||||
return nil, errors.New("dynamic forward listen address is empty")
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", req.ListenAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
forwarder := newPortForwarder(listener)
|
||||
forwarder.serveDynamic(func(ctx context.Context, targetAddr string) (net.Conn, error) {
|
||||
return s.DialTCPContext(ctx, "tcp", targetAddr)
|
||||
})
|
||||
return forwarder, nil
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartDynamicForwardDetached(req DynamicForwardRequest) (*PortForwarder, error) {
|
||||
if _, err := s.requireSSHClient(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(req.ListenAddr) == "" {
|
||||
return nil, errors.New("dynamic forward listen address is empty")
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", req.ListenAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
forwardClient, err := s.newForwardDialClient(context.Background())
|
||||
if err != nil {
|
||||
_ = listener.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
forwarder := newPortForwarder(listener)
|
||||
forwarder.addCleanup(func() error {
|
||||
return normalizeAlreadyClosedError(forwardClient.Close())
|
||||
})
|
||||
forwarder.serveDynamic(func(ctx context.Context, targetAddr string) (net.Conn, error) {
|
||||
return forwardClient.DialTCPContext(ctx, "tcp", targetAddr)
|
||||
})
|
||||
return forwarder, nil
|
||||
}
|
||||
|
||||
func (f *PortForwarder) Addr() net.Addr {
|
||||
if f == nil || f.listener == nil {
|
||||
return nil
|
||||
}
|
||||
return f.listener.Addr()
|
||||
}
|
||||
|
||||
func (f *PortForwarder) Wait() error {
|
||||
if f == nil {
|
||||
return nil
|
||||
}
|
||||
<-f.acceptDone
|
||||
f.connWG.Wait()
|
||||
f.runCleanup()
|
||||
return f.Err()
|
||||
}
|
||||
|
||||
func (f *PortForwarder) Err() error {
|
||||
if f == nil {
|
||||
return nil
|
||||
}
|
||||
f.errMu.Lock()
|
||||
defer f.errMu.Unlock()
|
||||
return f.err
|
||||
}
|
||||
|
||||
func (f *PortForwarder) Close() error {
|
||||
if f == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var closeErr error
|
||||
f.closeOnce.Do(func() {
|
||||
if f.cancel != nil {
|
||||
f.cancel()
|
||||
}
|
||||
if f.listener != nil {
|
||||
closeErr = normalizeAlreadyClosedError(f.listener.Close())
|
||||
}
|
||||
f.closeActiveConnections()
|
||||
})
|
||||
|
||||
<-f.acceptDone
|
||||
f.connWG.Wait()
|
||||
f.runCleanup()
|
||||
if closeErr != nil {
|
||||
return closeErr
|
||||
}
|
||||
return f.Err()
|
||||
}
|
||||
|
||||
func newPortForwarder(listener net.Listener) *PortForwarder {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &PortForwarder{
|
||||
listener: listener,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
acceptDone: make(chan struct{}),
|
||||
conns: make(map[net.Conn]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StarSSH) newForwardDialClient(ctx context.Context) (*StarSSH, error) {
|
||||
if s == nil {
|
||||
return nil, errors.New("ssh client is nil")
|
||||
}
|
||||
return newDetachedForwardClient(ctx, s.LoginInfo)
|
||||
}
|
||||
|
||||
func (f *PortForwarder) addCleanup(fn func() error) {
|
||||
if f == nil || fn == nil {
|
||||
return
|
||||
}
|
||||
f.cleanupFns = append(f.cleanupFns, fn)
|
||||
}
|
||||
|
||||
func (f *PortForwarder) runCleanup() {
|
||||
if f == nil {
|
||||
return
|
||||
}
|
||||
|
||||
f.cleanupOnce.Do(func() {
|
||||
for _, fn := range f.cleanupFns {
|
||||
f.setError(normalizeAlreadyClosedError(fn()))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (f *PortForwarder) serve(targetDial func(context.Context) (net.Conn, error)) {
|
||||
go func() {
|
||||
defer close(f.acceptDone)
|
||||
|
||||
for {
|
||||
conn, err := f.listener.Accept()
|
||||
if err != nil {
|
||||
if isClosedListenerError(err) {
|
||||
return
|
||||
}
|
||||
f.setError(err)
|
||||
return
|
||||
}
|
||||
|
||||
f.trackConn(conn)
|
||||
f.connWG.Add(1)
|
||||
go func(src net.Conn) {
|
||||
defer f.connWG.Done()
|
||||
defer f.untrackConn(src)
|
||||
|
||||
dst, err := targetDial(f.ctx)
|
||||
if err != nil {
|
||||
f.setError(err)
|
||||
_ = src.Close()
|
||||
return
|
||||
}
|
||||
f.trackConn(dst)
|
||||
defer f.untrackConn(dst)
|
||||
|
||||
f.setError(pipeForwardConnections(src, dst))
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (f *PortForwarder) serveDynamic(targetDial func(context.Context, string) (net.Conn, error)) {
|
||||
go func() {
|
||||
defer close(f.acceptDone)
|
||||
|
||||
for {
|
||||
conn, err := f.listener.Accept()
|
||||
if err != nil {
|
||||
if isClosedListenerError(err) {
|
||||
return
|
||||
}
|
||||
f.setError(err)
|
||||
return
|
||||
}
|
||||
|
||||
f.trackConn(conn)
|
||||
f.connWG.Add(1)
|
||||
go func(src net.Conn) {
|
||||
defer f.connWG.Done()
|
||||
defer f.untrackConn(src)
|
||||
if err := handleDynamicForwardConn(f.ctx, src, targetDial, f.trackConn, f.untrackConn); err != nil {
|
||||
f.setError(err)
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (f *PortForwarder) setError(err error) {
|
||||
if f == nil || err == nil || f.shouldIgnoreError(err) {
|
||||
return
|
||||
}
|
||||
|
||||
f.errMu.Lock()
|
||||
defer f.errMu.Unlock()
|
||||
if f.err == nil {
|
||||
f.err = err
|
||||
}
|
||||
}
|
||||
|
||||
func pipeForwardConnections(left net.Conn, right net.Conn) error {
|
||||
if left == nil || right == nil {
|
||||
if left != nil {
|
||||
_ = left.Close()
|
||||
}
|
||||
if right != nil {
|
||||
_ = right.Close()
|
||||
}
|
||||
return errors.New("forward connection endpoint is nil")
|
||||
}
|
||||
|
||||
defer left.Close()
|
||||
defer right.Close()
|
||||
|
||||
var copyWG sync.WaitGroup
|
||||
errCh := make(chan error, 2)
|
||||
copyWG.Add(2)
|
||||
go func() {
|
||||
defer copyWG.Done()
|
||||
_, err := io.Copy(right, left)
|
||||
errCh <- normalizeAlreadyClosedError(err)
|
||||
closeWrite(right)
|
||||
}()
|
||||
go func() {
|
||||
defer copyWG.Done()
|
||||
_, err := io.Copy(left, right)
|
||||
errCh <- normalizeAlreadyClosedError(err)
|
||||
closeWrite(left)
|
||||
}()
|
||||
copyWG.Wait()
|
||||
close(errCh)
|
||||
|
||||
for err := range errCh {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func closeWrite(conn net.Conn) {
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
type closeWriter interface {
|
||||
CloseWrite() error
|
||||
}
|
||||
if writer, ok := conn.(closeWriter); ok {
|
||||
_ = writer.CloseWrite()
|
||||
}
|
||||
}
|
||||
|
||||
func handleDynamicForwardConn(
|
||||
ctx context.Context,
|
||||
src net.Conn,
|
||||
targetDial func(context.Context, string) (net.Conn, error),
|
||||
trackConn func(net.Conn),
|
||||
untrackConn func(net.Conn),
|
||||
) error {
|
||||
if src == nil {
|
||||
return errors.New("dynamic forward source connection is nil")
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
if err := negotiateSOCKS5NoAuth(src); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
targetAddr, replyCode, err := readSOCKS5ConnectTarget(src)
|
||||
if err != nil {
|
||||
_ = writeSOCKS5ServerReply(src, replyCode, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
dst, err := targetDial(ctx, targetAddr)
|
||||
if err != nil {
|
||||
_ = writeSOCKS5ServerReply(src, 0x01, nil)
|
||||
return err
|
||||
}
|
||||
if trackConn != nil {
|
||||
trackConn(dst)
|
||||
defer untrackConn(dst)
|
||||
}
|
||||
|
||||
if err := writeSOCKS5ServerReply(src, 0x00, dst.LocalAddr()); err != nil {
|
||||
_ = dst.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return pipeForwardConnections(src, dst)
|
||||
}
|
||||
|
||||
func (f *PortForwarder) trackConn(conn net.Conn) {
|
||||
if f == nil || conn == nil {
|
||||
return
|
||||
}
|
||||
f.connMu.Lock()
|
||||
defer f.connMu.Unlock()
|
||||
f.conns[conn] = struct{}{}
|
||||
}
|
||||
|
||||
func (f *PortForwarder) untrackConn(conn net.Conn) {
|
||||
if f == nil || conn == nil {
|
||||
return
|
||||
}
|
||||
f.connMu.Lock()
|
||||
defer f.connMu.Unlock()
|
||||
delete(f.conns, conn)
|
||||
}
|
||||
|
||||
func (f *PortForwarder) closeActiveConnections() {
|
||||
if f == nil {
|
||||
return
|
||||
}
|
||||
|
||||
f.connMu.Lock()
|
||||
conns := make([]net.Conn, 0, len(f.conns))
|
||||
for conn := range f.conns {
|
||||
conns = append(conns, conn)
|
||||
}
|
||||
f.connMu.Unlock()
|
||||
|
||||
for _, conn := range conns {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (f *PortForwarder) shouldIgnoreError(err error) bool {
|
||||
if err == nil {
|
||||
return true
|
||||
}
|
||||
if normalizeAlreadyClosedError(err) == nil {
|
||||
return true
|
||||
}
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func negotiateSOCKS5NoAuth(conn net.Conn) error {
|
||||
header := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, header); err != nil {
|
||||
return err
|
||||
}
|
||||
if header[0] != 0x05 {
|
||||
return errors.New("invalid socks5 version")
|
||||
}
|
||||
|
||||
methodCount := int(header[1])
|
||||
methods := make([]byte, methodCount)
|
||||
if _, err := io.ReadFull(conn, methods); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
method := byte(0xFF)
|
||||
for _, candidate := range methods {
|
||||
if candidate == 0x00 {
|
||||
method = 0x00
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := conn.Write([]byte{0x05, method}); err != nil {
|
||||
return err
|
||||
}
|
||||
if method == 0xFF {
|
||||
return errors.New("socks5 client does not support no-auth method")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func readSOCKS5ConnectTarget(conn net.Conn) (string, byte, error) {
|
||||
header := make([]byte, 4)
|
||||
if _, err := io.ReadFull(conn, header); err != nil {
|
||||
return "", 0x01, err
|
||||
}
|
||||
if header[0] != 0x05 {
|
||||
return "", 0x01, errors.New("invalid socks5 request version")
|
||||
}
|
||||
if header[1] != 0x01 {
|
||||
return "", 0x07, errors.New("unsupported socks5 command")
|
||||
}
|
||||
|
||||
host, err := readSOCKS5RequestHost(conn, header[3])
|
||||
if err != nil {
|
||||
return "", 0x08, err
|
||||
}
|
||||
|
||||
portBytes := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, portBytes); err != nil {
|
||||
return "", 0x01, err
|
||||
}
|
||||
port := int(portBytes[0])<<8 | int(portBytes[1])
|
||||
return net.JoinHostPort(host, strconv.Itoa(port)), 0x00, nil
|
||||
}
|
||||
|
||||
func readSOCKS5RequestHost(conn net.Conn, addressType byte) (string, error) {
|
||||
switch addressType {
|
||||
case 0x01:
|
||||
buffer := make([]byte, 4)
|
||||
if _, err := io.ReadFull(conn, buffer); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return net.IP(buffer).String(), nil
|
||||
case 0x03:
|
||||
size := make([]byte, 1)
|
||||
if _, err := io.ReadFull(conn, size); err != nil {
|
||||
return "", err
|
||||
}
|
||||
buffer := make([]byte, int(size[0]))
|
||||
if _, err := io.ReadFull(conn, buffer); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(buffer), nil
|
||||
case 0x04:
|
||||
buffer := make([]byte, 16)
|
||||
if _, err := io.ReadFull(conn, buffer); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return net.IP(buffer).String(), nil
|
||||
default:
|
||||
return "", errors.New("unsupported socks5 address type")
|
||||
}
|
||||
}
|
||||
|
||||
func writeSOCKS5ServerReply(conn net.Conn, replyCode byte, addr net.Addr) error {
|
||||
reply := []byte{0x05, replyCode, 0x00}
|
||||
|
||||
if tcpAddr, ok := addr.(*net.TCPAddr); ok && tcpAddr != nil {
|
||||
if ip4 := tcpAddr.IP.To4(); ip4 != nil {
|
||||
reply = append(reply, 0x01)
|
||||
reply = append(reply, ip4...)
|
||||
} else if ip16 := tcpAddr.IP.To16(); ip16 != nil {
|
||||
reply = append(reply, 0x04)
|
||||
reply = append(reply, ip16...)
|
||||
}
|
||||
if len(reply) > 3 {
|
||||
port := tcpAddr.Port
|
||||
reply = append(reply, byte(port>>8), byte(port))
|
||||
_, err := conn.Write(reply)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
reply = append(reply, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00)
|
||||
_, err := conn.Write(reply)
|
||||
return err
|
||||
}
|
||||
|
||||
func isClosedListenerError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
return true
|
||||
}
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return true
|
||||
}
|
||||
return strings.Contains(err.Error(), "use of closed network connection")
|
||||
}
|
||||
164
forward_test.go
Normal file
164
forward_test.go
Normal file
@ -0,0 +1,164 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestStartLocalForwardUsesExistingConnectionByDefault(t *testing.T) {
|
||||
oldDialSSHClient := dialSSHClient
|
||||
oldNewDetachedForwardClient := newDetachedForwardClient
|
||||
oldCloseSSHClient := closeSSHClient
|
||||
t.Cleanup(func() {
|
||||
dialSSHClient = oldDialSSHClient
|
||||
newDetachedForwardClient = oldNewDetachedForwardClient
|
||||
closeSSHClient = oldCloseSSHClient
|
||||
})
|
||||
|
||||
baseClient := &ssh.Client{}
|
||||
star := &StarSSH{}
|
||||
star.setTransport(baseClient, nil)
|
||||
|
||||
var detachedCalls atomic.Int32
|
||||
newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) {
|
||||
detachedCalls.Add(1)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
|
||||
if client != baseClient {
|
||||
t.Errorf("expected existing ssh client, got %p want %p", client, baseClient)
|
||||
}
|
||||
serverConn, clientConn := net.Pipe()
|
||||
go echoForwardPipe(serverConn)
|
||||
return clientConn, nil
|
||||
}
|
||||
|
||||
closeSSHClient = func(client sshClientRequester) error {
|
||||
t.Fatal("default local forward should not close the main ssh client")
|
||||
return nil
|
||||
}
|
||||
|
||||
forwarder, err := star.StartLocalForward(ForwardRequest{
|
||||
ListenAddr: "127.0.0.1:0",
|
||||
TargetAddr: "example.internal:22",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("start local forward: %v", err)
|
||||
}
|
||||
defer forwarder.Close()
|
||||
|
||||
reply := exerciseForwarder(t, forwarder.Addr().String(), []byte("ping"))
|
||||
if string(reply) != "ping" {
|
||||
t.Fatalf("unexpected forwarded reply: %q", string(reply))
|
||||
}
|
||||
if detachedCalls.Load() != 0 {
|
||||
t.Fatalf("default local forward should not create detached ssh client, calls=%d", detachedCalls.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartLocalForwardDetachedUsesSeparateConnection(t *testing.T) {
|
||||
oldDialSSHClient := dialSSHClient
|
||||
oldNewDetachedForwardClient := newDetachedForwardClient
|
||||
oldCloseSSHClient := closeSSHClient
|
||||
t.Cleanup(func() {
|
||||
dialSSHClient = oldDialSSHClient
|
||||
newDetachedForwardClient = oldNewDetachedForwardClient
|
||||
closeSSHClient = oldCloseSSHClient
|
||||
})
|
||||
|
||||
baseClient := &ssh.Client{}
|
||||
detachedClient := &ssh.Client{}
|
||||
star := &StarSSH{LoginInfo: LoginInput{User: "tester", Addr: "127.0.0.1"}}
|
||||
star.setTransport(baseClient, nil)
|
||||
|
||||
forwardClient := &StarSSH{}
|
||||
forwardClient.setTransport(detachedClient, nil)
|
||||
|
||||
var detachedCalls atomic.Int32
|
||||
newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) {
|
||||
detachedCalls.Add(1)
|
||||
return forwardClient, nil
|
||||
}
|
||||
|
||||
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
|
||||
if client != detachedClient {
|
||||
t.Errorf("expected detached ssh client, got %p want %p", client, detachedClient)
|
||||
}
|
||||
serverConn, clientConn := net.Pipe()
|
||||
go echoForwardPipe(serverConn)
|
||||
return clientConn, nil
|
||||
}
|
||||
|
||||
var closeCalls atomic.Int32
|
||||
closeSSHClient = func(client sshClientRequester) error {
|
||||
closeCalls.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
forwarder, err := star.StartLocalForwardDetached(ForwardRequest{
|
||||
ListenAddr: "127.0.0.1:0",
|
||||
TargetAddr: "example.internal:22",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("start detached local forward: %v", err)
|
||||
}
|
||||
|
||||
reply := exerciseForwarder(t, forwarder.Addr().String(), []byte("pong"))
|
||||
if string(reply) != "pong" {
|
||||
t.Fatalf("unexpected detached forwarded reply: %q", string(reply))
|
||||
}
|
||||
|
||||
if err := forwarder.Close(); err != nil {
|
||||
t.Fatalf("close detached local forward: %v", err)
|
||||
}
|
||||
if detachedCalls.Load() != 1 {
|
||||
t.Fatalf("expected one detached ssh login, got %d", detachedCalls.Load())
|
||||
}
|
||||
if closeCalls.Load() != 1 {
|
||||
t.Fatalf("expected detached ssh client cleanup once, got %d", closeCalls.Load())
|
||||
}
|
||||
if got := star.snapshotSSHClient(); got != baseClient {
|
||||
t.Fatal("detached local forward should not detach the main ssh client")
|
||||
}
|
||||
if got := forwardClient.snapshotSSHClient(); got != nil {
|
||||
t.Fatal("detached local forward should close its detached ssh client")
|
||||
}
|
||||
}
|
||||
|
||||
func echoForwardPipe(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
buf := make([]byte, 4096)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = conn.Write(buf[:n])
|
||||
}
|
||||
|
||||
func exerciseForwarder(t *testing.T, addr string, payload []byte) []byte {
|
||||
t.Helper()
|
||||
|
||||
conn, err := net.DialTimeout("tcp", addr, time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("dial forward listener: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
|
||||
|
||||
if _, err := conn.Write(payload); err != nil {
|
||||
t.Fatalf("write forwarded payload: %v", err)
|
||||
}
|
||||
|
||||
reply := make([]byte, len(payload))
|
||||
if _, err := io.ReadFull(conn, reply); err != nil {
|
||||
t.Fatalf("read forwarded reply: %v", err)
|
||||
}
|
||||
return reply
|
||||
}
|
||||
15
go.mod
15
go.mod
@ -1,8 +1,17 @@
|
||||
module b612.me/starssh
|
||||
|
||||
go 1.16
|
||||
go 1.20
|
||||
|
||||
require (
|
||||
github.com/pkg/sftp v1.13.4
|
||||
golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000
|
||||
github.com/Microsoft/go-winio v0.6.1
|
||||
github.com/pkg/sftp v1.13.9
|
||||
golang.org/x/crypto v0.33.0
|
||||
golang.org/x/sys v0.30.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/kr/fs v0.1.0 // indirect
|
||||
golang.org/x/mod v0.17.0 // indirect
|
||||
golang.org/x/sync v0.10.0 // indirect
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
|
||||
)
|
||||
|
||||
93
go.sum
93
go.sum
@ -1,29 +1,92 @@
|
||||
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
||||
github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow=
|
||||
github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
|
||||
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
|
||||
github.com/pkg/sftp v1.13.4 h1:Lb0RYJCmgUcBgZosfoi9Y9sbl6+LJgOIgk/2Y4YjMFg=
|
||||
github.com/pkg/sftp v1.13.4/go.mod h1:LzqnAvaD5TWeNBsZpfKxSYn1MbjWwOsCIAFFJbpIsK8=
|
||||
github.com/pkg/sftp v1.13.9 h1:4NGkvGudBL7GteO3m6qnaQ4pC0Kvf0onSVc9gR3EWBw=
|
||||
github.com/pkg/sftp v1.13.9/go.mod h1:OBN7bVXdstkFFN/gdnHPUb5TE8eb8G1Rp9wCItqjkkA=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||
golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000 h1:SL+8VVnkqyshUSz5iNnXtrBQzvFF2SkROm6t5RczFAE=
|
||||
golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
|
||||
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
|
||||
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
|
||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
|
||||
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
||||
golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
|
||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
||||
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
|
||||
golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg=
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
290
hostkey.go
Normal file
290
hostkey.go
Normal file
@ -0,0 +1,290 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/knownhosts"
|
||||
)
|
||||
|
||||
var ErrKnownHostsFileRequired = errors.New("known_hosts file is required")
|
||||
var ErrHostFingerprintRequired = errors.New("host key fingerprint is required")
|
||||
|
||||
type HostKeyFingerprintMismatchError struct {
|
||||
Expected []string
|
||||
ActualSHA256 string
|
||||
ActualLegacyMD5 string
|
||||
}
|
||||
|
||||
type AcceptNewHostKeyOptions struct {
|
||||
KnownHostsFile string
|
||||
HashHosts bool
|
||||
IncludeRemoteAddress bool
|
||||
FileMode os.FileMode
|
||||
}
|
||||
|
||||
func (e *HostKeyFingerprintMismatchError) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("host key fingerprint mismatch: want one of %s, got %s (%s)", strings.Join(e.Expected, ", "), e.ActualSHA256, e.ActualLegacyMD5)
|
||||
}
|
||||
|
||||
func KnownHostsHostKeyCallback(files ...string) (func(string, net.Addr, ssh.PublicKey) error, error) {
|
||||
trimmed := make([]string, 0, len(files))
|
||||
for _, file := range files {
|
||||
file = strings.TrimSpace(file)
|
||||
if file == "" {
|
||||
continue
|
||||
}
|
||||
trimmed = append(trimmed, file)
|
||||
}
|
||||
if len(trimmed) == 0 {
|
||||
return nil, ErrKnownHostsFileRequired
|
||||
}
|
||||
return knownhosts.New(trimmed...)
|
||||
}
|
||||
|
||||
func AcceptNewHostKeyCallback(file string) (func(string, net.Addr, ssh.PublicKey) error, error) {
|
||||
return AcceptNewHostKeyCallbackWithOptions(AcceptNewHostKeyOptions{
|
||||
KnownHostsFile: file,
|
||||
IncludeRemoteAddress: true,
|
||||
})
|
||||
}
|
||||
|
||||
func AcceptNewHostKeyCallbackWithOptions(options AcceptNewHostKeyOptions) (func(string, net.Addr, ssh.PublicKey) error, error) {
|
||||
options = normalizeAcceptNewHostKeyOptions(options)
|
||||
if options.KnownHostsFile == "" {
|
||||
return nil, ErrKnownHostsFileRequired
|
||||
}
|
||||
|
||||
state := &acceptNewHostKeyState{
|
||||
file: options.KnownHostsFile,
|
||||
hashHosts: options.HashHosts,
|
||||
includeRemoteAddress: options.IncludeRemoteAddress,
|
||||
fileMode: options.FileMode,
|
||||
}
|
||||
if err := state.reload(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return state.checkHostKey, nil
|
||||
}
|
||||
|
||||
func FingerprintHostKeyCallback(fingerprints ...string) (func(string, net.Addr, ssh.PublicKey) error, error) {
|
||||
normalized := make([]string, 0, len(fingerprints))
|
||||
seen := make(map[string]struct{}, len(fingerprints))
|
||||
for _, raw := range fingerprints {
|
||||
fingerprint, err := normalizeHostKeyFingerprint(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if fingerprint == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[fingerprint]; exists {
|
||||
continue
|
||||
}
|
||||
seen[fingerprint] = struct{}{}
|
||||
normalized = append(normalized, fingerprint)
|
||||
}
|
||||
if len(normalized) == 0 {
|
||||
return nil, ErrHostFingerprintRequired
|
||||
}
|
||||
|
||||
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
actualSHA256 := ssh.FingerprintSHA256(key)
|
||||
actualMD5 := normalizeMD5Fingerprint(ssh.FingerprintLegacyMD5(key))
|
||||
for _, want := range normalized {
|
||||
if want == actualSHA256 || want == actualMD5 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return &HostKeyFingerprintMismatchError{
|
||||
Expected: append([]string(nil), normalized...),
|
||||
ActualSHA256: actualSHA256,
|
||||
ActualLegacyMD5: actualMD5,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
type acceptNewHostKeyState struct {
|
||||
file string
|
||||
hashHosts bool
|
||||
includeRemoteAddress bool
|
||||
fileMode os.FileMode
|
||||
|
||||
mu sync.Mutex
|
||||
cb ssh.HostKeyCallback
|
||||
}
|
||||
|
||||
func (s *acceptNewHostKeyState) checkHostKey(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.cb != nil {
|
||||
err := s.cb(hostname, remote, key)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var keyErr *knownhosts.KeyError
|
||||
if !errors.As(err, &keyErr) || len(keyErr.Want) != 0 {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
line, err := buildAcceptNewKnownHostsLine(hostname, remote, key, s.hashHosts, s.includeRemoteAddress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := appendKnownHostsLine(s.file, line, s.fileMode); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.reload(); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.cb == nil {
|
||||
return errors.New("known_hosts callback is nil after reload")
|
||||
}
|
||||
return s.cb(hostname, remote, key)
|
||||
}
|
||||
|
||||
func (s *acceptNewHostKeyState) reload() error {
|
||||
callback, err := loadKnownHostsCallback(s.file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.cb = callback
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadKnownHostsCallback(file string) (ssh.HostKeyCallback, error) {
|
||||
_, err := os.Stat(file)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return knownhosts.New(file)
|
||||
}
|
||||
|
||||
func normalizeAcceptNewHostKeyOptions(options AcceptNewHostKeyOptions) AcceptNewHostKeyOptions {
|
||||
options.KnownHostsFile = strings.TrimSpace(options.KnownHostsFile)
|
||||
if options.FileMode == 0 {
|
||||
options.FileMode = 0o600
|
||||
}
|
||||
return options
|
||||
}
|
||||
|
||||
func buildAcceptNewKnownHostsLine(hostname string, remote net.Addr, key ssh.PublicKey, hashHosts bool, includeRemoteAddress bool) (string, error) {
|
||||
addresses := collectKnownHostsAddresses(hostname, remote, includeRemoteAddress)
|
||||
if len(addresses) == 0 {
|
||||
return "", errors.New("no hostname or remote address available for known_hosts entry")
|
||||
}
|
||||
|
||||
patterns := make([]string, 0, len(addresses))
|
||||
for _, address := range addresses {
|
||||
normalized := knownhosts.Normalize(address)
|
||||
if hashHosts {
|
||||
patterns = append(patterns, knownhosts.HashHostname(normalized))
|
||||
continue
|
||||
}
|
||||
patterns = append(patterns, normalized)
|
||||
}
|
||||
|
||||
authorizedKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(key)))
|
||||
return strings.Join(patterns, ",") + " " + authorizedKey, nil
|
||||
}
|
||||
|
||||
func collectKnownHostsAddresses(hostname string, remote net.Addr, includeRemoteAddress bool) []string {
|
||||
addresses := make([]string, 0, 2)
|
||||
seen := make(map[string]struct{}, 2)
|
||||
|
||||
add := func(address string) {
|
||||
address = strings.TrimSpace(address)
|
||||
if address == "" {
|
||||
return
|
||||
}
|
||||
normalized := knownhosts.Normalize(address)
|
||||
if _, exists := seen[normalized]; exists {
|
||||
return
|
||||
}
|
||||
seen[normalized] = struct{}{}
|
||||
addresses = append(addresses, address)
|
||||
}
|
||||
|
||||
add(hostname)
|
||||
if includeRemoteAddress && remote != nil {
|
||||
add(remote.String())
|
||||
}
|
||||
|
||||
return addresses
|
||||
}
|
||||
|
||||
func appendKnownHostsLine(file string, line string, mode os.FileMode) error {
|
||||
if strings.TrimSpace(file) == "" {
|
||||
return ErrKnownHostsFileRequired
|
||||
}
|
||||
if strings.TrimSpace(line) == "" {
|
||||
return errors.New("known_hosts line is empty")
|
||||
}
|
||||
|
||||
dir := filepath.Dir(file)
|
||||
if dir != "." && dir != "" {
|
||||
if err := os.MkdirAll(dir, 0o700); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
handle, err := os.OpenFile(file, os.O_CREATE|os.O_APPEND|os.O_WRONLY, mode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer handle.Close()
|
||||
|
||||
if _, err := handle.WriteString(line + "\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
return handle.Chmod(mode)
|
||||
}
|
||||
|
||||
func normalizeHostKeyFingerprint(raw string) (string, error) {
|
||||
value := strings.TrimSpace(raw)
|
||||
if value == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(strings.ToUpper(value), "SHA256:") {
|
||||
suffix := strings.TrimSpace(value[len("SHA256:"):])
|
||||
if suffix == "" {
|
||||
return "", ErrHostFingerprintRequired
|
||||
}
|
||||
return "SHA256:" + suffix, nil
|
||||
}
|
||||
if strings.HasPrefix(strings.ToUpper(value), "MD5:") {
|
||||
suffix := strings.TrimSpace(value[len("MD5:"):])
|
||||
if suffix == "" {
|
||||
return "", ErrHostFingerprintRequired
|
||||
}
|
||||
return normalizeMD5Fingerprint("MD5:" + suffix), nil
|
||||
}
|
||||
if strings.Count(value, ":") >= 2 {
|
||||
return normalizeMD5Fingerprint("MD5:" + value), nil
|
||||
}
|
||||
return "SHA256:" + value, nil
|
||||
}
|
||||
|
||||
func normalizeMD5Fingerprint(value string) string {
|
||||
if !strings.HasPrefix(strings.ToUpper(value), "MD5:") {
|
||||
return "MD5:" + strings.ToLower(value)
|
||||
}
|
||||
return "MD5:" + strings.ToLower(value[len("MD5:"):])
|
||||
}
|
||||
135
keepalive.go
Normal file
135
keepalive.go
Normal file
@ -0,0 +1,135 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var sendKeepAliveRequest = func(ctx context.Context, client sshClientRequester) error {
|
||||
_, _, err := client.SendRequest("keepalive@openssh.com", true, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *StarSSH) Ping() error {
|
||||
return s.PingContext(context.Background())
|
||||
}
|
||||
|
||||
func (s *StarSSH) PingContext(ctx context.Context) error {
|
||||
return s.pingContext(ctx, nil)
|
||||
}
|
||||
|
||||
func (s *StarSSH) PingContextCloseOnCancel(ctx context.Context) error {
|
||||
return s.pingContext(ctx, s.Close)
|
||||
}
|
||||
|
||||
func (s *StarSSH) pingContext(ctx context.Context, onCancel func() error) error {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
client, err := s.requireSSHClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
runCancel := func() {}
|
||||
if onCancel != nil {
|
||||
var cancelOnce sync.Once
|
||||
runCancel = func() {
|
||||
cancelOnce.Do(func() {
|
||||
_ = onCancel()
|
||||
})
|
||||
}
|
||||
|
||||
cancelDone := make(chan struct{})
|
||||
defer close(cancelDone)
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
runCancel()
|
||||
case <-cancelDone:
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
requestFunc := sendKeepAliveRequest
|
||||
pingErr := make(chan error, 1)
|
||||
go func() {
|
||||
err := requestFunc(ctx, client)
|
||||
select {
|
||||
case pingErr <- err:
|
||||
default:
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-pingErr:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
runCancel()
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StarSSH) startAutoKeepAlive() {
|
||||
if s == nil || s.snapshotSSHClient() == nil {
|
||||
return
|
||||
}
|
||||
|
||||
interval := s.LoginInfo.KeepAliveInterval
|
||||
if interval <= 0 {
|
||||
return
|
||||
}
|
||||
timeout := s.LoginInfo.KeepAliveTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = defaultKeepAliveTimeout
|
||||
}
|
||||
|
||||
stop := make(chan struct{})
|
||||
done := make(chan struct{})
|
||||
|
||||
s.keepaliveMu.Lock()
|
||||
if s.keepaliveStop != nil {
|
||||
s.keepaliveMu.Unlock()
|
||||
return
|
||||
}
|
||||
s.keepaliveStop = stop
|
||||
s.keepaliveDone = done
|
||||
s.keepaliveMu.Unlock()
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
defer close(done)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case <-ticker.C:
|
||||
pingCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
err := s.pingContext(pingCtx, nil)
|
||||
cancel()
|
||||
if err != nil {
|
||||
s.closeFromKeepAlive()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *StarSSH) stopAutoKeepAlive() {
|
||||
stop, done := s.takeKeepaliveHandles()
|
||||
if stop != nil {
|
||||
close(stop)
|
||||
}
|
||||
if done != nil {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StarSSH) closeFromKeepAlive() {
|
||||
_ = s.closeTransport(false)
|
||||
}
|
||||
53
keepalive_test.go
Normal file
53
keepalive_test.go
Normal file
@ -0,0 +1,53 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestAutoKeepAliveTimeoutDoesNotDeadlock(t *testing.T) {
|
||||
oldSendKeepAliveRequest := sendKeepAliveRequest
|
||||
oldCloseSSHClient := closeSSHClient
|
||||
t.Cleanup(func() {
|
||||
sendKeepAliveRequest = oldSendKeepAliveRequest
|
||||
closeSSHClient = oldCloseSSHClient
|
||||
})
|
||||
|
||||
sendKeepAliveRequest = func(ctx context.Context, client sshClientRequester) error {
|
||||
<-ctx.Done()
|
||||
return ctx.Err()
|
||||
}
|
||||
closeSSHClient = func(client sshClientRequester) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
client := &ssh.Client{}
|
||||
star := &StarSSH{
|
||||
LoginInfo: LoginInput{
|
||||
KeepAliveInterval: 10 * time.Millisecond,
|
||||
KeepAliveTimeout: 20 * time.Millisecond,
|
||||
},
|
||||
}
|
||||
star.setTransport(client, nil)
|
||||
star.startAutoKeepAlive()
|
||||
|
||||
star.keepaliveMu.Lock()
|
||||
done := star.keepaliveDone
|
||||
star.keepaliveMu.Unlock()
|
||||
if done == nil {
|
||||
t.Fatal("keepalive did not start")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("keepalive goroutine did not exit after keepalive timeout")
|
||||
}
|
||||
|
||||
if client := star.snapshotSSHClient(); client != nil {
|
||||
t.Fatal("ssh client should be closed after keepalive timeout")
|
||||
}
|
||||
}
|
||||
362
login.go
Normal file
362
login.go
Normal file
@ -0,0 +1,362 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/agent"
|
||||
)
|
||||
|
||||
var ErrHostKeyCallbackRequired = errors.New("host key callback is required; use DefaultAllowHostKeyCallback to explicitly allow any host key")
|
||||
var errSSHAgentUnavailable = errors.New("ssh-agent unavailable")
|
||||
|
||||
var defaultAuthOrder = []AuthMethodKind{
|
||||
AuthMethodSSHAgent,
|
||||
AuthMethodPrivateKey,
|
||||
AuthMethodPassword,
|
||||
AuthMethodKeyboardInteractive,
|
||||
}
|
||||
|
||||
func DefaultAllowHostKeyCallback(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func LoginContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
|
||||
return loginWithContext(ctx, info)
|
||||
}
|
||||
|
||||
func Login(info LoginInput) (*StarSSH, error) {
|
||||
return LoginContext(context.Background(), info)
|
||||
}
|
||||
|
||||
func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
|
||||
info = normalizeLoginInput(info)
|
||||
if info.HostKeyCallback == nil {
|
||||
return nil, ErrHostKeyCallbackRequired
|
||||
}
|
||||
|
||||
loginCtx, cancel := contextWithLoginTimeout(ctx, info.Timeout)
|
||||
defer cancel()
|
||||
|
||||
sshInfo := &StarSSH{
|
||||
LoginInfo: info,
|
||||
}
|
||||
|
||||
auth, authCleanup, err := buildAuthMethods(info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if authCleanup != nil {
|
||||
defer authCleanup()
|
||||
}
|
||||
|
||||
hostKeyCallback := func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
sshInfo.PublicKey = key
|
||||
sshInfo.RemoteAddr = remote
|
||||
sshInfo.Hostname = hostname
|
||||
|
||||
return info.HostKeyCallback(hostname, remote, key)
|
||||
}
|
||||
|
||||
bannerCallback := func(banner string) error {
|
||||
sshInfo.Banner = banner
|
||||
if info.BannerCallback != nil {
|
||||
return info.BannerCallback(banner)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
clientConfig := &ssh.ClientConfig{
|
||||
User: info.User,
|
||||
Auth: auth,
|
||||
Timeout: info.Timeout,
|
||||
HostKeyCallback: hostKeyCallback,
|
||||
BannerCallback: bannerCallback,
|
||||
}
|
||||
if len(info.Ciphers) > 0 || len(info.MACs) > 0 || len(info.KeyExchanges) > 0 {
|
||||
clientConfig.Config = ssh.Config{
|
||||
Ciphers: info.Ciphers,
|
||||
MACs: info.MACs,
|
||||
KeyExchanges: info.KeyExchanges,
|
||||
}
|
||||
}
|
||||
|
||||
targetAddr := joinHostPort(info.Addr, info.Port)
|
||||
rawConn, upstream, err := dialTargetConn(loginCtx, info)
|
||||
if err != nil {
|
||||
return sshInfo, err
|
||||
}
|
||||
restoreDeadline := applyConnDeadline(rawConn, loginCtx, info.Timeout)
|
||||
defer restoreDeadline()
|
||||
|
||||
clientConn, chans, reqs, err := ssh.NewClientConn(rawConn, targetAddr, clientConfig)
|
||||
if err != nil {
|
||||
_ = rawConn.Close()
|
||||
if upstream != nil {
|
||||
_ = upstream.Close()
|
||||
}
|
||||
return sshInfo, err
|
||||
}
|
||||
client := ssh.NewClient(clientConn, chans, reqs)
|
||||
|
||||
sshInfo.setTransport(client, upstream)
|
||||
if sshInfo.PublicKey != nil {
|
||||
sshInfo.PubkeyBase64 = base64.StdEncoding.EncodeToString(sshInfo.PublicKey.Marshal())
|
||||
}
|
||||
sshInfo.startAutoKeepAlive()
|
||||
|
||||
return sshInfo, nil
|
||||
}
|
||||
|
||||
func contextWithLoginTimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if timeout <= 0 {
|
||||
return ctx, func() {}
|
||||
}
|
||||
return context.WithTimeout(ctx, timeout)
|
||||
}
|
||||
|
||||
func LoginSimple(host string, user string, passwd string, prikeyPath string, port int, timeout time.Duration) (*StarSSH, error) {
|
||||
info := LoginInput{
|
||||
Addr: host,
|
||||
Port: port,
|
||||
Timeout: timeout,
|
||||
User: user,
|
||||
HostKeyCallback: DefaultAllowHostKeyCallback,
|
||||
}
|
||||
|
||||
if prikeyPath != "" {
|
||||
prikey, err := os.ReadFile(prikeyPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info.Prikey = string(prikey)
|
||||
if passwd != "" {
|
||||
info.PrikeyPwd = passwd
|
||||
}
|
||||
} else {
|
||||
info.Password = passwd
|
||||
}
|
||||
|
||||
return Login(info)
|
||||
}
|
||||
|
||||
func normalizeLoginInput(info LoginInput) LoginInput {
|
||||
if info.Port <= 0 {
|
||||
info.Port = defaultSSHPort
|
||||
}
|
||||
if info.Timeout <= 0 {
|
||||
info.Timeout = defaultLoginTimeout
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) {
|
||||
order, err := normalizeAuthOrder(info.AuthOrder)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
auth := make([]ssh.AuthMethod, 0, len(order))
|
||||
var agentErr error
|
||||
var cleanupFuncs []func()
|
||||
|
||||
for _, methodKind := range order {
|
||||
switch methodKind {
|
||||
case AuthMethodPrivateKey:
|
||||
method, err := buildPrivateKeyAuthMethod(info)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if method != nil {
|
||||
auth = append(auth, method)
|
||||
}
|
||||
case AuthMethodPassword:
|
||||
method := buildPasswordAuthMethod(info.Password, info.PasswordCallback)
|
||||
if method != nil {
|
||||
auth = append(auth, method)
|
||||
}
|
||||
case AuthMethodKeyboardInteractive:
|
||||
method := buildKeyboardInteractiveAuthMethod(info.Password, info.PasswordCallback, info.KeyboardInteractiveCallback)
|
||||
if method != nil {
|
||||
auth = append(auth, method)
|
||||
}
|
||||
case AuthMethodSSHAgent:
|
||||
if info.DisableSSHAgent {
|
||||
continue
|
||||
}
|
||||
agentMethod, cleanup, err := buildSSHAgentAuthMethod(info.Timeout)
|
||||
if err != nil {
|
||||
agentErr = err
|
||||
continue
|
||||
}
|
||||
if agentMethod != nil {
|
||||
auth = append(auth, agentMethod)
|
||||
}
|
||||
if cleanup != nil {
|
||||
cleanupFuncs = append(cleanupFuncs, cleanup)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(auth) == 0 {
|
||||
if agentErr != nil {
|
||||
return nil, nil, fmt.Errorf("no authentication method provided; ssh-agent unavailable: %w", agentErr)
|
||||
}
|
||||
return nil, nil, errors.New("no authentication method provided: password, private key, or ssh-agent is required")
|
||||
}
|
||||
|
||||
return auth, composeCleanup(cleanupFuncs...), nil
|
||||
}
|
||||
|
||||
func normalizeAuthOrder(order []AuthMethodKind) ([]AuthMethodKind, error) {
|
||||
if len(order) == 0 {
|
||||
return append([]AuthMethodKind(nil), defaultAuthOrder...), nil
|
||||
}
|
||||
|
||||
normalized := make([]AuthMethodKind, 0, len(order))
|
||||
seen := make(map[AuthMethodKind]struct{}, len(order))
|
||||
for _, raw := range order {
|
||||
kind := AuthMethodKind(strings.ToLower(strings.TrimSpace(string(raw))))
|
||||
if kind == "" {
|
||||
return nil, errors.New("auth order contains an empty auth method")
|
||||
}
|
||||
if !isSupportedAuthMethodKind(kind) {
|
||||
return nil, fmt.Errorf("unsupported auth method %q", raw)
|
||||
}
|
||||
if _, exists := seen[kind]; exists {
|
||||
continue
|
||||
}
|
||||
seen[kind] = struct{}{}
|
||||
normalized = append(normalized, kind)
|
||||
}
|
||||
|
||||
if len(normalized) == 0 {
|
||||
return nil, errors.New("auth order is empty")
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
func isSupportedAuthMethodKind(kind AuthMethodKind) bool {
|
||||
switch kind {
|
||||
case AuthMethodPrivateKey, AuthMethodPassword, AuthMethodKeyboardInteractive, AuthMethodSSHAgent:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func buildPrivateKeyAuthMethod(info LoginInput) (ssh.AuthMethod, error) {
|
||||
if strings.TrimSpace(info.Prikey) == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
pemBytes := []byte(info.Prikey)
|
||||
if info.PrikeyPwd == "" {
|
||||
signer, err := ssh.ParsePrivateKey(pemBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ssh.PublicKeys(signer), nil
|
||||
}
|
||||
|
||||
signer, err := ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(info.PrikeyPwd))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ssh.PublicKeys(signer), nil
|
||||
}
|
||||
|
||||
func buildPasswordAuthMethod(password string, callback func() (string, error)) ssh.AuthMethod {
|
||||
if password != "" {
|
||||
return ssh.Password(password)
|
||||
}
|
||||
if callback != nil {
|
||||
return ssh.PasswordCallback(callback)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildKeyboardInteractiveAuthMethod(
|
||||
password string,
|
||||
passwordCallback func() (string, error),
|
||||
challenge ssh.KeyboardInteractiveChallenge,
|
||||
) ssh.AuthMethod {
|
||||
if challenge != nil {
|
||||
return ssh.KeyboardInteractive(challenge)
|
||||
}
|
||||
if password == "" && passwordCallback == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
keyboardInteractiveChallenge := func(user, instruction string, questions []string, echos []bool) ([]string, error) {
|
||||
if len(questions) == 0 {
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
answer := password
|
||||
if answer == "" {
|
||||
var err error
|
||||
answer, err = passwordCallback()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
answers := make([]string, len(questions))
|
||||
for i := range questions {
|
||||
answers[i] = answer
|
||||
}
|
||||
return answers, nil
|
||||
}
|
||||
return ssh.KeyboardInteractive(keyboardInteractiveChallenge)
|
||||
}
|
||||
|
||||
func buildSSHAgentAuthMethod(timeout time.Duration) (ssh.AuthMethod, func(), error) {
|
||||
conn, err := dialSSHAgent(timeout)
|
||||
if err != nil {
|
||||
if errors.Is(err, errSSHAgentUnavailable) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
return nil, nil, err
|
||||
}
|
||||
if conn == nil {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
signers, err := agent.NewClient(conn).Signers()
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
if len(signers) == 0 {
|
||||
_ = conn.Close()
|
||||
return nil, nil, errors.New("ssh-agent has no loaded keys")
|
||||
}
|
||||
|
||||
return ssh.PublicKeys(signers...), func() {
|
||||
_ = conn.Close()
|
||||
}, nil
|
||||
}
|
||||
|
||||
func composeCleanup(funcs ...func()) func() {
|
||||
if len(funcs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return func() {
|
||||
for i := len(funcs) - 1; i >= 0; i-- {
|
||||
if funcs[i] != nil {
|
||||
funcs[i]()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
466
pool.go
Normal file
466
pool.go
Normal file
@ -0,0 +1,466 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const defaultExecPoolMaxOpenConns = 4
|
||||
|
||||
var ErrExecPoolClosed = errors.New("exec pool is closed")
|
||||
|
||||
type ExecPoolConfig struct {
|
||||
Login LoginInput
|
||||
MaxOpenConns int
|
||||
MaxIdleConns int
|
||||
MaxIdleTime time.Duration
|
||||
DisableHealthCheck bool
|
||||
HealthCheckTimeout time.Duration
|
||||
}
|
||||
|
||||
type ExecPoolStats struct {
|
||||
MaxOpenConns int
|
||||
MaxIdleConns int
|
||||
MaxIdleTime time.Duration
|
||||
OpenConns int
|
||||
IdleConns int
|
||||
InUseConns int
|
||||
}
|
||||
|
||||
type ExecPool struct {
|
||||
loginInfo LoginInput
|
||||
maxOpen int
|
||||
maxIdle int
|
||||
maxIdleTime time.Duration
|
||||
|
||||
idle chan *pooledClient
|
||||
done chan struct{}
|
||||
closeOnce sync.Once
|
||||
healthCheckOnAcquire bool
|
||||
healthCheckTimeout time.Duration
|
||||
|
||||
mu sync.Mutex
|
||||
open int
|
||||
closed bool
|
||||
}
|
||||
|
||||
type pooledClient struct {
|
||||
client *StarSSH
|
||||
idleAt time.Time
|
||||
}
|
||||
|
||||
func NewExecPool(config ExecPoolConfig) *ExecPool {
|
||||
maxOpen := config.MaxOpenConns
|
||||
if maxOpen <= 0 {
|
||||
maxOpen = defaultExecPoolMaxOpenConns
|
||||
}
|
||||
|
||||
maxIdle := config.MaxIdleConns
|
||||
if maxIdle <= 0 || maxIdle > maxOpen {
|
||||
maxIdle = maxOpen
|
||||
}
|
||||
|
||||
return &ExecPool{
|
||||
loginInfo: config.Login,
|
||||
maxOpen: maxOpen,
|
||||
maxIdle: maxIdle,
|
||||
maxIdleTime: normalizeMaxIdleTime(config.MaxIdleTime),
|
||||
idle: make(chan *pooledClient, maxIdle),
|
||||
done: make(chan struct{}),
|
||||
healthCheckOnAcquire: !config.DisableHealthCheck,
|
||||
healthCheckTimeout: normalizeHealthCheckTimeout(config.HealthCheckTimeout),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ExecPool) Exec(ctx context.Context, req ExecRequest) (*ExecResult, error) {
|
||||
client, err := p.Acquire(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result, execErr := client.Exec(ctx, req)
|
||||
if execErr != nil {
|
||||
p.Discard(client)
|
||||
return result, execErr
|
||||
}
|
||||
if releaseErr := p.Release(client); releaseErr != nil {
|
||||
return result, releaseErr
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (p *ExecPool) ExecString(ctx context.Context, command string) (*ExecResult, error) {
|
||||
return p.Exec(ctx, ExecRequest{
|
||||
Command: command,
|
||||
})
|
||||
}
|
||||
|
||||
func (p *ExecPool) ExecStream(ctx context.Context, req ExecRequest, onChunk func(ExecStreamChunk)) (*ExecResult, error) {
|
||||
client, err := p.Acquire(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result, execErr := client.ExecStream(ctx, req, onChunk)
|
||||
if execErr != nil {
|
||||
p.Discard(client)
|
||||
return result, execErr
|
||||
}
|
||||
if releaseErr := p.Release(client); releaseErr != nil {
|
||||
return result, releaseErr
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (p *ExecPool) WarmUp(ctx context.Context, targetIdle int) error {
|
||||
if p == nil {
|
||||
return errors.New("exec pool is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
targetIdle = p.normalizeWarmUpTarget(targetIdle)
|
||||
if targetIdle == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
idleCount, create, err := p.tryWarmUp(targetIdle)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if idleCount >= targetIdle || !create {
|
||||
return nil
|
||||
}
|
||||
|
||||
conn, err := LoginContext(ctx, p.loginInfo)
|
||||
if err != nil {
|
||||
p.releaseSlot()
|
||||
return err
|
||||
}
|
||||
if err := p.Release(conn); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ExecPool) Acquire(ctx context.Context) (*StarSSH, error) {
|
||||
if p == nil {
|
||||
return nil, errors.New("exec pool is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
for {
|
||||
idleClient, create, err := p.tryAcquire()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if idleClient != nil {
|
||||
client, ok := p.takeIdleClient(ctx, idleClient)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
if create {
|
||||
conn, err := LoginContext(ctx, p.loginInfo)
|
||||
if err != nil {
|
||||
p.releaseSlot()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-p.done:
|
||||
return nil, ErrExecPoolClosed
|
||||
case idleClient = <-p.idle:
|
||||
if idleClient == nil {
|
||||
continue
|
||||
}
|
||||
client, ok := p.takeIdleClient(ctx, idleClient)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ExecPool) Release(client *StarSSH) error {
|
||||
if p == nil {
|
||||
return errors.New("exec pool is nil")
|
||||
}
|
||||
if client == nil {
|
||||
p.releaseSlot()
|
||||
return nil
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
if p.closed {
|
||||
p.mu.Unlock()
|
||||
p.closeClient(client)
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case p.idle <- &pooledClient{
|
||||
client: client,
|
||||
idleAt: time.Now(),
|
||||
}:
|
||||
p.mu.Unlock()
|
||||
return nil
|
||||
default:
|
||||
p.mu.Unlock()
|
||||
p.closeClient(client)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ExecPool) Discard(client *StarSSH) {
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
if client == nil {
|
||||
p.releaseSlot()
|
||||
return
|
||||
}
|
||||
p.closeClient(client)
|
||||
}
|
||||
|
||||
func (p *ExecPool) Stats() ExecPoolStats {
|
||||
if p == nil {
|
||||
return ExecPoolStats{}
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
idleCount := len(p.idle)
|
||||
openCount := p.open
|
||||
inUseCount := openCount - idleCount
|
||||
if inUseCount < 0 {
|
||||
inUseCount = 0
|
||||
}
|
||||
|
||||
return ExecPoolStats{
|
||||
MaxOpenConns: p.maxOpen,
|
||||
MaxIdleConns: p.maxIdle,
|
||||
MaxIdleTime: p.maxIdleTime,
|
||||
OpenConns: openCount,
|
||||
IdleConns: idleCount,
|
||||
InUseConns: inUseCount,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ExecPool) Close() error {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var closeErr error
|
||||
p.closeOnce.Do(func() {
|
||||
p.mu.Lock()
|
||||
p.closed = true
|
||||
idleClients := p.drainIdleLocked()
|
||||
p.mu.Unlock()
|
||||
close(p.done)
|
||||
|
||||
for _, client := range idleClients {
|
||||
if err := client.Close(); err != nil && closeErr == nil {
|
||||
closeErr = err
|
||||
}
|
||||
}
|
||||
})
|
||||
return closeErr
|
||||
}
|
||||
|
||||
func (p *ExecPool) CloseIdle() error {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
idleClients := p.drainIdleLocked()
|
||||
p.mu.Unlock()
|
||||
|
||||
var closeErr error
|
||||
for _, client := range idleClients {
|
||||
if err := client.Close(); err != nil && closeErr == nil {
|
||||
closeErr = err
|
||||
}
|
||||
}
|
||||
return closeErr
|
||||
}
|
||||
|
||||
func (p *ExecPool) tryAcquire() (*pooledClient, bool, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.closed {
|
||||
return nil, false, ErrExecPoolClosed
|
||||
}
|
||||
|
||||
select {
|
||||
case client := <-p.idle:
|
||||
return client, false, nil
|
||||
default:
|
||||
}
|
||||
|
||||
if p.open < p.maxOpen {
|
||||
p.open++
|
||||
return nil, true, nil
|
||||
}
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func (p *ExecPool) tryWarmUp(targetIdle int) (int, bool, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.closed {
|
||||
return 0, false, ErrExecPoolClosed
|
||||
}
|
||||
|
||||
idleCount := len(p.idle)
|
||||
if idleCount >= targetIdle {
|
||||
return idleCount, false, nil
|
||||
}
|
||||
if p.open >= p.maxOpen {
|
||||
return idleCount, false, nil
|
||||
}
|
||||
|
||||
p.open++
|
||||
return idleCount, true, nil
|
||||
}
|
||||
|
||||
func (p *ExecPool) normalizeWarmUpTarget(targetIdle int) int {
|
||||
if p == nil || p.maxIdle <= 0 {
|
||||
return 0
|
||||
}
|
||||
if targetIdle <= 0 {
|
||||
return p.maxIdle
|
||||
}
|
||||
if targetIdle > p.maxIdle {
|
||||
return p.maxIdle
|
||||
}
|
||||
return targetIdle
|
||||
}
|
||||
|
||||
func (p *ExecPool) takeIdleClient(ctx context.Context, idleClient *pooledClient) (*StarSSH, bool) {
|
||||
if idleClient == nil {
|
||||
return nil, false
|
||||
}
|
||||
if idleClient.client == nil {
|
||||
p.releaseSlot()
|
||||
return nil, false
|
||||
}
|
||||
if p.isIdleExpired(idleClient) {
|
||||
p.closePooledClient(idleClient)
|
||||
return nil, false
|
||||
}
|
||||
if err := p.healthCheckClient(ctx, idleClient.client); err != nil {
|
||||
p.closePooledClient(idleClient)
|
||||
return nil, false
|
||||
}
|
||||
return idleClient.client, true
|
||||
}
|
||||
|
||||
func (p *ExecPool) isIdleExpired(client *pooledClient) bool {
|
||||
if p == nil || client == nil || client.client == nil {
|
||||
return false
|
||||
}
|
||||
if p.maxIdleTime <= 0 || client.idleAt.IsZero() {
|
||||
return false
|
||||
}
|
||||
return time.Since(client.idleAt) >= p.maxIdleTime
|
||||
}
|
||||
|
||||
func (p *ExecPool) drainIdleLocked() []*StarSSH {
|
||||
clients := make([]*StarSSH, 0, len(p.idle))
|
||||
for {
|
||||
select {
|
||||
case idleClient := <-p.idle:
|
||||
if p.open > 0 {
|
||||
p.open--
|
||||
}
|
||||
if idleClient == nil || idleClient.client == nil {
|
||||
continue
|
||||
}
|
||||
clients = append(clients, idleClient.client)
|
||||
default:
|
||||
return clients
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ExecPool) releaseSlot() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.open > 0 {
|
||||
p.open--
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ExecPool) closeClient(client *StarSSH) {
|
||||
if client != nil {
|
||||
_ = client.Close()
|
||||
}
|
||||
p.releaseSlot()
|
||||
}
|
||||
|
||||
func (p *ExecPool) closePooledClient(client *pooledClient) {
|
||||
if client == nil {
|
||||
return
|
||||
}
|
||||
if client.client == nil {
|
||||
p.releaseSlot()
|
||||
return
|
||||
}
|
||||
p.closeClient(client.client)
|
||||
}
|
||||
|
||||
func (p *ExecPool) healthCheckClient(ctx context.Context, client *StarSSH) error {
|
||||
if client == nil {
|
||||
return errors.New("ssh client is nil")
|
||||
}
|
||||
if !p.healthCheckOnAcquire {
|
||||
return nil
|
||||
}
|
||||
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
timeout := p.healthCheckTimeout
|
||||
if timeout > 0 {
|
||||
healthCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
ctx = healthCtx
|
||||
}
|
||||
return client.PingContext(ctx)
|
||||
}
|
||||
|
||||
func normalizeHealthCheckTimeout(timeout time.Duration) time.Duration {
|
||||
if timeout <= 0 {
|
||||
return defaultKeepAliveTimeout
|
||||
}
|
||||
return timeout
|
||||
}
|
||||
|
||||
func normalizeMaxIdleTime(timeout time.Duration) time.Duration {
|
||||
if timeout <= 0 {
|
||||
return 0
|
||||
}
|
||||
return timeout
|
||||
}
|
||||
121
session.go
Normal file
121
session.go
Normal file
@ -0,0 +1,121 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func (s *StarSSH) Close() error {
|
||||
return s.closeTransport(true)
|
||||
}
|
||||
|
||||
func (s *StarSSH) NewSession() (*ssh.Session, error) {
|
||||
return s.NewPTYSession(nil)
|
||||
}
|
||||
|
||||
func (s *StarSSH) NewExecSession() (*ssh.Session, error) {
|
||||
client, err := s.requireSSHClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewExecSession(client)
|
||||
}
|
||||
|
||||
func (s *StarSSH) NewPTYSession(config *TerminalConfig) (*ssh.Session, error) {
|
||||
client, err := s.requireSSHClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewPTYSession(client, config)
|
||||
}
|
||||
|
||||
func NewTransferSession(client *ssh.Client) (*ssh.Session, error) {
|
||||
return NewExecSession(client)
|
||||
}
|
||||
|
||||
func NewExecSession(client *ssh.Client) (*ssh.Session, error) {
|
||||
if client == nil {
|
||||
return nil, errors.New("ssh client is nil")
|
||||
}
|
||||
return client.NewSession()
|
||||
}
|
||||
|
||||
func NewSession(client *ssh.Client) (*ssh.Session, error) {
|
||||
return NewPTYSession(client, nil)
|
||||
}
|
||||
|
||||
func NewPTYSession(client *ssh.Client, config *TerminalConfig) (*ssh.Session, error) {
|
||||
if client == nil {
|
||||
return nil, errors.New("ssh client is nil")
|
||||
}
|
||||
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg := normalizeTerminalConfig(config)
|
||||
if err := session.RequestPty(cfg.Term, cfg.Rows, cfg.Columns, cfg.Modes); err != nil {
|
||||
_ = session.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func normalizeTerminalConfig(config *TerminalConfig) TerminalConfig {
|
||||
cfg := TerminalConfig{
|
||||
Term: defaultPTYTerm,
|
||||
Rows: defaultPTYRows,
|
||||
Columns: defaultPTYColumns,
|
||||
Modes: ssh.TerminalModes{
|
||||
ssh.ECHO: 1,
|
||||
ssh.TTY_OP_ISPEED: 14400,
|
||||
ssh.TTY_OP_OSPEED: 14400,
|
||||
},
|
||||
}
|
||||
if config == nil {
|
||||
return cfg
|
||||
}
|
||||
|
||||
if strings.TrimSpace(config.Term) != "" {
|
||||
cfg.Term = config.Term
|
||||
}
|
||||
if config.Rows > 0 {
|
||||
cfg.Rows = config.Rows
|
||||
}
|
||||
if config.Columns > 0 {
|
||||
cfg.Columns = config.Columns
|
||||
}
|
||||
if len(config.Modes) > 0 {
|
||||
cfg.Modes = config.Modes
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func normalizeAlreadyClosedError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
if strings.Contains(err.Error(), "use of closed network connection") {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func closeUpstream(upstream *StarSSH) error {
|
||||
if upstream == nil {
|
||||
return nil
|
||||
}
|
||||
return upstream.Close()
|
||||
}
|
||||
475
sftp_test.go
Normal file
475
sftp_test.go
Normal file
@ -0,0 +1,475 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/sftp"
|
||||
)
|
||||
|
||||
func TestNormalizeSFTPTransferOptionsDefaultsAtomicDownload(t *testing.T) {
|
||||
opts := normalizeSFTPTransferOptions(nil)
|
||||
if !opts.AtomicUpload {
|
||||
t.Fatal("expected atomic upload to default to enabled")
|
||||
}
|
||||
if !opts.AtomicDownload {
|
||||
t.Fatal("expected atomic download to default to enabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransferOutContextVerifyFailurePreservesRemoteTarget(t *testing.T) {
|
||||
client := newSFTPTestClient(t)
|
||||
root := t.TempDir()
|
||||
localPath := filepath.Join(root, "local.txt")
|
||||
remotePath := filepath.Join(root, "remote.txt")
|
||||
|
||||
if err := os.WriteFile(localPath, []byte("new payload"), 0o644); err != nil {
|
||||
t.Fatalf("write local file: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(remotePath, []byte("original remote"), 0o644); err != nil {
|
||||
t.Fatalf("write remote file: %v", err)
|
||||
}
|
||||
|
||||
verifyErr := errors.New("verify failed")
|
||||
var verifiedPath string
|
||||
oldVerifyRemoteSize := sftpVerifyRemoteSizeFunc
|
||||
sftpVerifyRemoteSizeFunc = func(client *sftp.Client, remotePath string, expected int64) error {
|
||||
verifiedPath = remotePath
|
||||
return verifyErr
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
sftpVerifyRemoteSizeFunc = oldVerifyRemoteSize
|
||||
})
|
||||
|
||||
err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil))
|
||||
if !errors.Is(err, verifyErr) {
|
||||
t.Fatalf("expected verify failure, got %v", err)
|
||||
}
|
||||
if verifiedPath == remotePath {
|
||||
t.Fatal("expected upload verification to run against temp path before final rename")
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(remotePath)
|
||||
if err != nil {
|
||||
t.Fatalf("read remote file: %v", err)
|
||||
}
|
||||
if string(data) != "original remote" {
|
||||
t.Fatalf("remote target was replaced on verify failure: %q", string(data))
|
||||
}
|
||||
assertNoTransferTemps(t, remotePath)
|
||||
}
|
||||
|
||||
func TestTransferOutContextRejectsRemoteSymlinkTarget(t *testing.T) {
|
||||
client := newSFTPTestClient(t)
|
||||
root := t.TempDir()
|
||||
localPath := filepath.Join(root, "local.txt")
|
||||
remoteRealPath := filepath.Join(root, "remote-real.txt")
|
||||
remotePath := filepath.Join(root, "remote-link.txt")
|
||||
|
||||
if err := os.WriteFile(localPath, []byte("new payload"), 0o644); err != nil {
|
||||
t.Fatalf("write local file: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(remoteRealPath, []byte("original remote"), 0o644); err != nil {
|
||||
t.Fatalf("write remote backing file: %v", err)
|
||||
}
|
||||
if err := os.Symlink(remoteRealPath, remotePath); err != nil {
|
||||
t.Skipf("symlink unsupported: %v", err)
|
||||
}
|
||||
|
||||
err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil))
|
||||
if err == nil || !strings.Contains(err.Error(), "symlink") {
|
||||
t.Fatalf("expected symlink rejection, got %v", err)
|
||||
}
|
||||
|
||||
info, err := os.Lstat(remotePath)
|
||||
if err != nil {
|
||||
t.Fatalf("lstat remote symlink: %v", err)
|
||||
}
|
||||
if info.Mode()&os.ModeSymlink == 0 {
|
||||
t.Fatal("expected remote target to remain a symlink")
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(remoteRealPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read remote backing file: %v", err)
|
||||
}
|
||||
if string(data) != "original remote" {
|
||||
t.Fatalf("remote backing file changed unexpectedly: %q", string(data))
|
||||
}
|
||||
assertNoTransferTemps(t, remotePath)
|
||||
}
|
||||
|
||||
func TestTransferOutContextRejectsRemoteDirectoryTarget(t *testing.T) {
|
||||
client := newSFTPTestClient(t)
|
||||
root := t.TempDir()
|
||||
localPath := filepath.Join(root, "local.txt")
|
||||
remotePath := filepath.Join(root, "remote-dir")
|
||||
|
||||
if err := os.WriteFile(localPath, []byte("new payload"), 0o644); err != nil {
|
||||
t.Fatalf("write local file: %v", err)
|
||||
}
|
||||
if err := os.Mkdir(remotePath, 0o755); err != nil {
|
||||
t.Fatalf("mkdir remote target: %v", err)
|
||||
}
|
||||
|
||||
err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil))
|
||||
if err == nil || !strings.Contains(err.Error(), "directory") {
|
||||
t.Fatalf("expected directory rejection, got %v", err)
|
||||
}
|
||||
|
||||
info, err := os.Stat(remotePath)
|
||||
if err != nil {
|
||||
t.Fatalf("stat remote directory: %v", err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
t.Fatal("expected remote target to remain a directory")
|
||||
}
|
||||
assertNoTransferTemps(t, remotePath)
|
||||
}
|
||||
|
||||
func TestTransferOutContextPreservesRemoteModeOnOverwrite(t *testing.T) {
|
||||
client := newSFTPTestClient(t)
|
||||
root := t.TempDir()
|
||||
localPath := filepath.Join(root, "local.txt")
|
||||
remotePath := filepath.Join(root, "remote.txt")
|
||||
|
||||
if err := os.WriteFile(localPath, []byte("new payload"), 0o644); err != nil {
|
||||
t.Fatalf("write local file: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(remotePath, []byte("original remote"), 0o755); err != nil {
|
||||
t.Fatalf("write remote file: %v", err)
|
||||
}
|
||||
if err := os.Chmod(remotePath, 0o755); err != nil {
|
||||
t.Fatalf("chmod remote file: %v", err)
|
||||
}
|
||||
|
||||
if err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil)); err != nil {
|
||||
t.Fatalf("transfer out: %v", err)
|
||||
}
|
||||
|
||||
assertMode(t, remotePath, 0o755)
|
||||
assertFileContent(t, remotePath, "new payload")
|
||||
}
|
||||
|
||||
func TestTransferOutContextAppliesLocalModeForNewRemoteFile(t *testing.T) {
|
||||
client := newSFTPTestClient(t)
|
||||
root := t.TempDir()
|
||||
localPath := filepath.Join(root, "local.txt")
|
||||
remotePath := filepath.Join(root, "remote.txt")
|
||||
|
||||
if err := os.WriteFile(localPath, []byte("new payload"), 0o751); err != nil {
|
||||
t.Fatalf("write local file: %v", err)
|
||||
}
|
||||
if err := os.Chmod(localPath, 0o751); err != nil {
|
||||
t.Fatalf("chmod local file: %v", err)
|
||||
}
|
||||
|
||||
if err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil)); err != nil {
|
||||
t.Fatalf("transfer out: %v", err)
|
||||
}
|
||||
|
||||
assertMode(t, remotePath, 0o751)
|
||||
assertFileContent(t, remotePath, "new payload")
|
||||
}
|
||||
|
||||
func TestTransferOutByteContextPreservesRemoteModeOnOverwrite(t *testing.T) {
|
||||
client := newSFTPTestClient(t)
|
||||
root := t.TempDir()
|
||||
remotePath := filepath.Join(root, "remote.txt")
|
||||
|
||||
if err := os.WriteFile(remotePath, []byte("original remote"), 0o755); err != nil {
|
||||
t.Fatalf("write remote file: %v", err)
|
||||
}
|
||||
if err := os.Chmod(remotePath, 0o755); err != nil {
|
||||
t.Fatalf("chmod remote file: %v", err)
|
||||
}
|
||||
|
||||
if err := transferOutByteContext(context.Background(), client, []byte("byte payload"), remotePath, normalizeSFTPTransferOptions(nil)); err != nil {
|
||||
t.Fatalf("transfer out bytes: %v", err)
|
||||
}
|
||||
|
||||
assertMode(t, remotePath, 0o755)
|
||||
assertFileContent(t, remotePath, "byte payload")
|
||||
}
|
||||
|
||||
func TestTransferInContextVerifyFailurePreservesLocalTarget(t *testing.T) {
|
||||
client := newSFTPTestClient(t)
|
||||
root := t.TempDir()
|
||||
srcPath := filepath.Join(root, "remote.txt")
|
||||
dstPath := filepath.Join(root, "local.txt")
|
||||
|
||||
if err := os.WriteFile(srcPath, []byte("fresh remote payload"), 0o644); err != nil {
|
||||
t.Fatalf("write remote file: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(dstPath, []byte("original local"), 0o644); err != nil {
|
||||
t.Fatalf("write local file: %v", err)
|
||||
}
|
||||
|
||||
verifyErr := errors.New("verify local failed")
|
||||
var verifiedPath string
|
||||
oldVerifyLocalSize := sftpVerifyLocalSizeFunc
|
||||
sftpVerifyLocalSizeFunc = func(localPath string, expected int64) error {
|
||||
verifiedPath = localPath
|
||||
return verifyErr
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
sftpVerifyLocalSizeFunc = oldVerifyLocalSize
|
||||
})
|
||||
|
||||
err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil))
|
||||
if !errors.Is(err, verifyErr) {
|
||||
t.Fatalf("expected verify failure, got %v", err)
|
||||
}
|
||||
if verifiedPath == dstPath {
|
||||
t.Fatal("expected download verification to run against temp path before final rename")
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(dstPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read local file: %v", err)
|
||||
}
|
||||
if string(data) != "original local" {
|
||||
t.Fatalf("local target was replaced on verify failure: %q", string(data))
|
||||
}
|
||||
assertNoTransferTemps(t, dstPath)
|
||||
}
|
||||
|
||||
func TestTransferInContextRejectsLocalSymlinkTarget(t *testing.T) {
|
||||
client := newSFTPTestClient(t)
|
||||
root := t.TempDir()
|
||||
srcPath := filepath.Join(root, "remote.txt")
|
||||
localRealPath := filepath.Join(root, "local-real.txt")
|
||||
dstPath := filepath.Join(root, "local-link.txt")
|
||||
|
||||
if err := os.WriteFile(srcPath, []byte("fresh remote payload"), 0o644); err != nil {
|
||||
t.Fatalf("write remote file: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(localRealPath, []byte("original local"), 0o644); err != nil {
|
||||
t.Fatalf("write local backing file: %v", err)
|
||||
}
|
||||
if err := os.Symlink(localRealPath, dstPath); err != nil {
|
||||
t.Skipf("symlink unsupported: %v", err)
|
||||
}
|
||||
|
||||
err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil))
|
||||
if err == nil || !strings.Contains(err.Error(), "symlink") {
|
||||
t.Fatalf("expected symlink rejection, got %v", err)
|
||||
}
|
||||
|
||||
info, err := os.Lstat(dstPath)
|
||||
if err != nil {
|
||||
t.Fatalf("lstat local symlink: %v", err)
|
||||
}
|
||||
if info.Mode()&os.ModeSymlink == 0 {
|
||||
t.Fatal("expected local target to remain a symlink")
|
||||
}
|
||||
assertFileContent(t, localRealPath, "original local")
|
||||
assertNoTransferTemps(t, dstPath)
|
||||
}
|
||||
|
||||
func TestTransferInContextRejectsLocalDirectoryTarget(t *testing.T) {
|
||||
client := newSFTPTestClient(t)
|
||||
root := t.TempDir()
|
||||
srcPath := filepath.Join(root, "remote.txt")
|
||||
dstPath := filepath.Join(root, "local-dir")
|
||||
|
||||
if err := os.WriteFile(srcPath, []byte("fresh remote payload"), 0o644); err != nil {
|
||||
t.Fatalf("write remote file: %v", err)
|
||||
}
|
||||
if err := os.Mkdir(dstPath, 0o755); err != nil {
|
||||
t.Fatalf("mkdir local target: %v", err)
|
||||
}
|
||||
|
||||
err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil))
|
||||
if err == nil || !strings.Contains(err.Error(), "directory") {
|
||||
t.Fatalf("expected directory rejection, got %v", err)
|
||||
}
|
||||
|
||||
info, err := os.Stat(dstPath)
|
||||
if err != nil {
|
||||
t.Fatalf("stat local directory: %v", err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
t.Fatal("expected local target to remain a directory")
|
||||
}
|
||||
assertNoTransferTemps(t, dstPath)
|
||||
}
|
||||
|
||||
func TestTransferInContextPreservesLocalModeOnOverwrite(t *testing.T) {
|
||||
client := newSFTPTestClient(t)
|
||||
root := t.TempDir()
|
||||
srcPath := filepath.Join(root, "remote.txt")
|
||||
dstPath := filepath.Join(root, "local.sh")
|
||||
|
||||
if err := os.WriteFile(srcPath, []byte("#!/bin/sh\necho remote\n"), 0o644); err != nil {
|
||||
t.Fatalf("write remote file: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(dstPath, []byte("#!/bin/sh\necho local\n"), 0o755); err != nil {
|
||||
t.Fatalf("write local file: %v", err)
|
||||
}
|
||||
if err := os.Chmod(dstPath, 0o755); err != nil {
|
||||
t.Fatalf("chmod local file: %v", err)
|
||||
}
|
||||
|
||||
if err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)); err != nil {
|
||||
t.Fatalf("transfer in: %v", err)
|
||||
}
|
||||
|
||||
assertMode(t, dstPath, 0o755)
|
||||
assertFileContent(t, dstPath, "#!/bin/sh\necho remote\n")
|
||||
}
|
||||
|
||||
func TestTransferInContextAppliesRemoteModeForNewLocalFile(t *testing.T) {
|
||||
client := newSFTPTestClient(t)
|
||||
root := t.TempDir()
|
||||
srcPath := filepath.Join(root, "remote.sh")
|
||||
dstPath := filepath.Join(root, "local.sh")
|
||||
|
||||
if err := os.WriteFile(srcPath, []byte("#!/bin/sh\necho remote\n"), 0o751); err != nil {
|
||||
t.Fatalf("write remote file: %v", err)
|
||||
}
|
||||
if err := os.Chmod(srcPath, 0o751); err != nil {
|
||||
t.Fatalf("chmod remote file: %v", err)
|
||||
}
|
||||
|
||||
if err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)); err != nil {
|
||||
t.Fatalf("transfer in: %v", err)
|
||||
}
|
||||
|
||||
assertMode(t, dstPath, 0o751)
|
||||
assertFileContent(t, dstPath, "#!/bin/sh\necho remote\n")
|
||||
}
|
||||
|
||||
func TestTransferInContextCopyFailurePreservesLocalTarget(t *testing.T) {
|
||||
client := newSFTPTestClient(t)
|
||||
root := t.TempDir()
|
||||
srcPath := filepath.Join(root, "remote.txt")
|
||||
dstPath := filepath.Join(root, "local.txt")
|
||||
|
||||
if err := os.WriteFile(srcPath, []byte("fresh remote payload"), 0o644); err != nil {
|
||||
t.Fatalf("write remote file: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(dstPath, []byte("original local"), 0o644); err != nil {
|
||||
t.Fatalf("write local file: %v", err)
|
||||
}
|
||||
|
||||
copyErr := errors.New("copy failed")
|
||||
var copyTargetPath string
|
||||
oldCopy := sftpCopyWithProgressFunc
|
||||
sftpCopyWithProgressFunc = func(ctx context.Context, dst io.Writer, src io.Reader, bufSize int, total int64, progress func(float64)) (int64, error) {
|
||||
file, ok := dst.(*os.File)
|
||||
if !ok {
|
||||
t.Fatalf("expected local temp file writer, got %T", dst)
|
||||
}
|
||||
copyTargetPath = file.Name()
|
||||
|
||||
buf := make([]byte, 8)
|
||||
n, readErr := src.Read(buf)
|
||||
if readErr != nil && !errors.Is(readErr, io.EOF) {
|
||||
return 0, readErr
|
||||
}
|
||||
if n > 0 {
|
||||
written, err := dst.Write(buf[:n])
|
||||
if err != nil {
|
||||
return int64(written), err
|
||||
}
|
||||
return int64(written), copyErr
|
||||
}
|
||||
return 0, copyErr
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
sftpCopyWithProgressFunc = oldCopy
|
||||
})
|
||||
|
||||
err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil))
|
||||
if !errors.Is(err, copyErr) {
|
||||
t.Fatalf("expected copy failure, got %v", err)
|
||||
}
|
||||
if copyTargetPath == dstPath {
|
||||
t.Fatal("expected partial download writes to stay on temp path")
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(dstPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read local file: %v", err)
|
||||
}
|
||||
if string(data) != "original local" {
|
||||
t.Fatalf("local target was modified by partial download: %q", string(data))
|
||||
}
|
||||
assertNoTransferTemps(t, dstPath)
|
||||
}
|
||||
|
||||
func newSFTPTestClient(t *testing.T) *sftp.Client {
|
||||
t.Helper()
|
||||
|
||||
serverConn, clientConn := net.Pipe()
|
||||
server, err := sftp.NewServer(serverConn)
|
||||
if err != nil {
|
||||
t.Fatalf("create sftp server: %v", err)
|
||||
}
|
||||
|
||||
serveErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
serveErrCh <- server.Serve()
|
||||
}()
|
||||
|
||||
client, err := sftp.NewClientPipe(clientConn, clientConn)
|
||||
if err != nil {
|
||||
_ = server.Close()
|
||||
t.Fatalf("create sftp client: %v", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = client.Close()
|
||||
_ = server.Close()
|
||||
serveErr := <-serveErrCh
|
||||
if serveErr == nil || errors.Is(serveErr, io.EOF) || normalizeAlreadyClosedError(serveErr) == nil {
|
||||
return
|
||||
}
|
||||
t.Errorf("unexpected sftp server error: %v", serveErr)
|
||||
})
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
func assertNoTransferTemps(t *testing.T, targetPath string) {
|
||||
t.Helper()
|
||||
|
||||
matches, err := filepath.Glob(targetPath + defaultSFTPTempSuffix + "*")
|
||||
if err != nil {
|
||||
t.Fatalf("glob temp files: %v", err)
|
||||
}
|
||||
if len(matches) != 0 {
|
||||
t.Fatalf("expected temp artifacts to be cleaned up, got %v", matches)
|
||||
}
|
||||
}
|
||||
|
||||
func assertMode(t *testing.T, targetPath string, want os.FileMode) {
|
||||
t.Helper()
|
||||
|
||||
info, err := os.Stat(targetPath)
|
||||
if err != nil {
|
||||
t.Fatalf("stat %q: %v", targetPath, err)
|
||||
}
|
||||
if got := info.Mode().Perm(); got != want {
|
||||
t.Fatalf("unexpected mode for %q: got %o want %o", targetPath, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func assertFileContent(t *testing.T, targetPath string, want string) {
|
||||
t.Helper()
|
||||
|
||||
data, err := os.ReadFile(targetPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read %q: %v", targetPath, err)
|
||||
}
|
||||
if string(data) != want {
|
||||
t.Fatalf("unexpected content for %q: got %q want %q", targetPath, string(data), want)
|
||||
}
|
||||
}
|
||||
592
shell.go
Normal file
592
shell.go
Normal file
@ -0,0 +1,592 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var errStarShellPOSIXOnly = errors.New("legacy StarShell only supports POSIX-compatible shells")
|
||||
|
||||
type ShellRequest struct {
|
||||
Command string
|
||||
Timeout time.Duration
|
||||
Keyword string
|
||||
UseWaitDefault *bool
|
||||
}
|
||||
|
||||
// NewShell creates the legacy prompt-driven POSIX shell helper.
|
||||
// For raw interactive terminal flows, prefer NewTerminal.
|
||||
func (s *StarSSH) NewShell() (shell *StarShell, err error) {
|
||||
shell = &StarShell{
|
||||
UseWaitDefault: true,
|
||||
WaitTimeout: defaultShellWaitTimeout,
|
||||
isecho: true,
|
||||
iscolor: true,
|
||||
promptToken: defaultShellPromptToken,
|
||||
}
|
||||
|
||||
shell.Session, err = s.NewPTYSession(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
shell.in, err = shell.Session.StdinPipe()
|
||||
if err != nil {
|
||||
_ = shell.Session.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stdout, err := shell.Session.StdoutPipe()
|
||||
if err != nil {
|
||||
_ = shell.Session.Close()
|
||||
return nil, err
|
||||
}
|
||||
shell.out = bufio.NewReader(stdout)
|
||||
|
||||
stderr, err := shell.Session.StderrPipe()
|
||||
if err != nil {
|
||||
_ = shell.Session.Close()
|
||||
return nil, err
|
||||
}
|
||||
shell.er = bufio.NewReader(stderr)
|
||||
|
||||
if err := shell.Session.Shell(); err != nil {
|
||||
_ = shell.Session.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go shell.watchSession()
|
||||
shell.gohub()
|
||||
|
||||
if err := shell.configurePrompt(context.Background()); err != nil {
|
||||
_ = shell.Session.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
shell.Clear()
|
||||
return shell, nil
|
||||
}
|
||||
|
||||
func (s *StarShell) configurePrompt(ctx context.Context) error {
|
||||
if s == nil {
|
||||
return errors.New("shell is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, defaultShellSetupTimeout)
|
||||
defer cancel()
|
||||
ctx = timeoutCtx
|
||||
}
|
||||
|
||||
prompt := s.promptToken + " "
|
||||
setupCommands := []string{
|
||||
"unset PROMPT_COMMAND >/dev/null 2>&1 || true",
|
||||
fmt.Sprintf("export PS1=%s PS2='' PROMPT=%s RPROMPT='' >/dev/null 2>&1 || true", shellSingleQuote(prompt), shellSingleQuote(prompt)),
|
||||
}
|
||||
|
||||
s.Clear()
|
||||
for _, cmd := range setupCommands {
|
||||
if err := s.WriteCommand(cmd); err != nil {
|
||||
return fmt.Errorf("%w: %v", errStarShellPOSIXOnly, err)
|
||||
}
|
||||
}
|
||||
|
||||
probeToken := "__STARSSH_POSIX_READY__" + newNonce(6)
|
||||
if err := s.WriteCommand(fmt.Sprintf("printf '%%s\\n' %s", shellSingleQuote(probeToken))); err != nil {
|
||||
return fmt.Errorf("%w: %v", errStarShellPOSIXOnly, err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(defaultShellPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
outRaw, errRaw, runErr := s.readState()
|
||||
if runErr != nil {
|
||||
return fmt.Errorf("%w: %v", errStarShellPOSIXOnly, runErr)
|
||||
}
|
||||
|
||||
outs := normalizeShellOutput(stripControlSequences(string(outRaw)))
|
||||
if strings.Contains(outs, probeToken) {
|
||||
s.Clear()
|
||||
return nil
|
||||
}
|
||||
|
||||
errs := normalizeShellOutput(stripControlSequences(string(errRaw)))
|
||||
if looksLikeNonPOSIXShellError(errs) {
|
||||
return fmt.Errorf("%w: %s", errStarShellPOSIXOnly, errs)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
|
||||
return fmt.Errorf("%w: prompt bootstrap timed out", errStarShellPOSIXOnly)
|
||||
}
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StarShell) watchSession() {
|
||||
if s == nil || s.Session == nil {
|
||||
return
|
||||
}
|
||||
if err := s.Session.Wait(); err != nil && !errors.Is(err, io.EOF) {
|
||||
s.setError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StarShell) Close() error {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var closeErr error
|
||||
s.closeOnce.Do(func() {
|
||||
if s.Session == nil {
|
||||
return
|
||||
}
|
||||
closeErr = s.Session.Close()
|
||||
})
|
||||
return closeErr
|
||||
}
|
||||
|
||||
func (s *StarShell) SwitchNoColor(run bool) {
|
||||
s.rw.Lock()
|
||||
defer s.rw.Unlock()
|
||||
s.iscolor = run
|
||||
}
|
||||
|
||||
func (s *StarShell) SwitchEcho(run bool) {
|
||||
s.rw.Lock()
|
||||
defer s.rw.Unlock()
|
||||
s.isecho = run
|
||||
}
|
||||
|
||||
func (s *StarShell) TrimColor(str string) string {
|
||||
s.rw.RLock()
|
||||
shouldTrim := s.iscolor
|
||||
s.rw.RUnlock()
|
||||
|
||||
if shouldTrim {
|
||||
return SedColor(str)
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
/*
|
||||
本函数控制是否在本地屏幕上打印远程Shell的输出内容[true|false]
|
||||
*/
|
||||
func (s *StarShell) SwitchPrint(run bool) {
|
||||
s.rw.Lock()
|
||||
defer s.rw.Unlock()
|
||||
s.isprint = run
|
||||
}
|
||||
|
||||
/*
|
||||
本函数控制是否立即处理远程Shell输出每一行内容[true|false]
|
||||
*/
|
||||
func (s *StarShell) SwitchFunc(run bool) {
|
||||
s.rw.Lock()
|
||||
defer s.rw.Unlock()
|
||||
s.isfuncs = run
|
||||
}
|
||||
|
||||
func (s *StarShell) SetFunc(funcs func(string)) {
|
||||
s.rw.Lock()
|
||||
defer s.rw.Unlock()
|
||||
s.funcs = funcs
|
||||
}
|
||||
|
||||
func (s *StarShell) Clear() {
|
||||
s.rw.Lock()
|
||||
defer s.rw.Unlock()
|
||||
s.outbyte = []byte{}
|
||||
s.errbyte = []byte{}
|
||||
}
|
||||
|
||||
func (s *StarShell) ShellClear(cmd string, sleep int) (string, string, error) {
|
||||
s.Clear()
|
||||
defer s.Clear()
|
||||
return s.Shell(cmd, sleep)
|
||||
}
|
||||
|
||||
func (s *StarShell) Shell(cmd string, sleep int) (string, string, error) {
|
||||
s.commandMu.Lock()
|
||||
defer s.commandMu.Unlock()
|
||||
|
||||
if err := s.WriteCommand(cmd); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
outRaw, errRaw, runErr := s.GetResult(sleep)
|
||||
if runErr != nil {
|
||||
return "", "", runErr
|
||||
}
|
||||
|
||||
outText := s.TrimColor(strings.TrimSpace(string(outRaw)))
|
||||
|
||||
s.rw.RLock()
|
||||
echoEnabled := s.isecho
|
||||
s.rw.RUnlock()
|
||||
if echoEnabled {
|
||||
outText = stripCommandEchoFromOutput(outText, cmd)
|
||||
}
|
||||
|
||||
return strings.TrimSpace(outText), s.TrimColor(strings.TrimSpace(string(errRaw))), nil
|
||||
}
|
||||
|
||||
func (s *StarShell) ShellWait(cmd string) (string, string, error) {
|
||||
result, err := s.Run(context.Background(), ShellRequest{
|
||||
Command: cmd,
|
||||
})
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return strings.TrimSpace(result.StdoutString()), strings.TrimSpace(result.StderrString()), result.CommandError()
|
||||
}
|
||||
|
||||
func (s *StarShell) RunString(ctx context.Context, command string) (*ExecResult, error) {
|
||||
return s.Run(ctx, ShellRequest{
|
||||
Command: command,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *StarShell) Run(ctx context.Context, req ShellRequest) (*ExecResult, error) {
|
||||
if s == nil {
|
||||
return nil, errors.New("shell is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
s.commandMu.Lock()
|
||||
defer s.commandMu.Unlock()
|
||||
|
||||
if strings.TrimSpace(req.Command) == "" {
|
||||
return nil, errors.New("command is empty")
|
||||
}
|
||||
|
||||
s.rw.RLock()
|
||||
useDefault := s.UseWaitDefault
|
||||
keyword := s.Keyword
|
||||
waitTimeout := s.WaitTimeout
|
||||
promptToken := s.promptToken
|
||||
s.rw.RUnlock()
|
||||
if req.UseWaitDefault != nil {
|
||||
useDefault = *req.UseWaitDefault
|
||||
}
|
||||
if strings.TrimSpace(req.Keyword) != "" {
|
||||
keyword = req.Keyword
|
||||
}
|
||||
if req.Timeout > 0 {
|
||||
waitTimeout = req.Timeout
|
||||
}
|
||||
if !useDefault && keyword == "" {
|
||||
return nil, errors.New("ShellRun requires UseWaitDefault=true or Keyword set")
|
||||
}
|
||||
if waitTimeout <= 0 {
|
||||
waitTimeout = defaultShellWaitTimeout
|
||||
}
|
||||
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, waitTimeout)
|
||||
defer cancel()
|
||||
ctx = timeoutCtx
|
||||
}
|
||||
startAt := time.Now()
|
||||
|
||||
s.Clear()
|
||||
defer s.Clear()
|
||||
|
||||
beginToken, endToken := newCommandTokens()
|
||||
markerCmd := fmt.Sprintf("__STARSSH_RC=$?; printf '%s:%%s\\n' \"$__STARSSH_RC\"", endToken)
|
||||
|
||||
if err := s.WriteCommand(fmt.Sprintf("printf '%s\\n'", beginToken)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.WriteCommand(req.Command); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.WriteCommand(markerCmd); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var (
|
||||
outc string
|
||||
errc string
|
||||
exitCode int
|
||||
done bool
|
||||
)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(defaultShellPollInterval):
|
||||
}
|
||||
|
||||
outRaw, errRaw, runErr := s.readState()
|
||||
if runErr != nil {
|
||||
return nil, runErr
|
||||
}
|
||||
|
||||
outs := normalizeShellOutput(stripControlSequences(string(outRaw)))
|
||||
errs := normalizeShellOutput(stripControlSequences(string(errRaw)))
|
||||
|
||||
s.rw.RLock()
|
||||
useDefault = s.UseWaitDefault
|
||||
keyword = s.Keyword
|
||||
s.rw.RUnlock()
|
||||
|
||||
if useDefault {
|
||||
segment, rc, found, parseErr := extractCommandSegment(outs, beginToken, endToken)
|
||||
if parseErr != nil {
|
||||
return nil, parseErr
|
||||
}
|
||||
if found {
|
||||
outc = segment
|
||||
errc = errs
|
||||
exitCode = rc
|
||||
done = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if keyword != "" {
|
||||
if strings.Contains(outs, keyword) || strings.Contains(errs, keyword) {
|
||||
outc = outs
|
||||
errc = errs
|
||||
done = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !done {
|
||||
return nil, errors.New("failed to collect shell result")
|
||||
}
|
||||
|
||||
outc = collectLinesForCommandOutput(outc, promptToken, beginToken, endToken)
|
||||
errc = collectLinesForCommandOutput(errc, promptToken, beginToken, endToken)
|
||||
outc = stripCommandEchoFromOutput(outc, req.Command)
|
||||
|
||||
stdoutText := strings.TrimSpace(outc)
|
||||
stderrText := strings.TrimSpace(errc)
|
||||
result := &ExecResult{
|
||||
Command: req.Command,
|
||||
Stdout: []byte(stdoutText),
|
||||
Stderr: []byte(stderrText),
|
||||
Combined: combineCommandOutput(stdoutText, stderrText),
|
||||
Duration: time.Since(startAt),
|
||||
}
|
||||
if useDefault {
|
||||
result.ExitCode = exitCode
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func extractCommandSegment(stdout string, beginToken string, endToken string) (string, int, bool, error) {
|
||||
lines := strings.Split(stdout, "\n")
|
||||
beginLine := -1
|
||||
for i, line := range lines {
|
||||
if strings.TrimSpace(line) == beginToken {
|
||||
beginLine = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if beginLine < 0 {
|
||||
return "", 0, false, nil
|
||||
}
|
||||
|
||||
segment := strings.Join(lines[beginLine+1:], "\n")
|
||||
before, rc, found, err := splitByEndToken(segment, endToken)
|
||||
if err != nil {
|
||||
return "", 0, false, err
|
||||
}
|
||||
if !found {
|
||||
return "", 0, false, nil
|
||||
}
|
||||
return strings.TrimSpace(before), rc, true, nil
|
||||
}
|
||||
|
||||
func (s *StarShell) GetResult(sleep int) ([]byte, []byte, error) {
|
||||
if sleep > 0 {
|
||||
time.Sleep(time.Millisecond * time.Duration(sleep))
|
||||
}
|
||||
return s.readState()
|
||||
}
|
||||
|
||||
func (s *StarShell) WriteCommand(cmd string) error {
|
||||
return s.Write([]byte(cmd + "\n"))
|
||||
}
|
||||
|
||||
func (s *StarShell) Write(bstr []byte) error {
|
||||
if s == nil {
|
||||
return errors.New("shell is nil")
|
||||
}
|
||||
if s.in == nil {
|
||||
return errors.New("shell stdin is not initialized")
|
||||
}
|
||||
if _, _, runErr := s.readState(); runErr != nil {
|
||||
return runErr
|
||||
}
|
||||
|
||||
s.writeMu.Lock()
|
||||
defer s.writeMu.Unlock()
|
||||
|
||||
_, err := s.in.Write(bstr)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *StarShell) gohub() {
|
||||
if s.er != nil {
|
||||
go s.streamPump(s.er, true)
|
||||
}
|
||||
if s.out != nil {
|
||||
go s.streamPump(s.out, false)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StarShell) streamPump(reader *bufio.Reader, isStderr bool) {
|
||||
var cache []byte
|
||||
|
||||
for {
|
||||
read, err := reader.ReadByte()
|
||||
if err == io.EOF {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
s.setError(err)
|
||||
return
|
||||
}
|
||||
|
||||
s.rw.Lock()
|
||||
if isStderr {
|
||||
s.errbyte = append(s.errbyte, read)
|
||||
} else {
|
||||
s.outbyte = append(s.outbyte, read)
|
||||
}
|
||||
printEnabled := s.isprint
|
||||
funcEnabled := s.isfuncs && s.funcs != nil
|
||||
lineHandler := s.funcs
|
||||
trimColor := s.iscolor
|
||||
s.rw.Unlock()
|
||||
|
||||
if printEnabled {
|
||||
fmt.Print(string([]byte{read}))
|
||||
}
|
||||
|
||||
cache = append(cache, read)
|
||||
if read == '\n' {
|
||||
if funcEnabled {
|
||||
line := strings.TrimSpace(string(cache))
|
||||
if trimColor {
|
||||
line = SedColor(line)
|
||||
}
|
||||
go lineHandler(line)
|
||||
}
|
||||
cache = cache[:0]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StarShell) setError(err error) {
|
||||
if err == nil || errors.Is(err, io.EOF) {
|
||||
return
|
||||
}
|
||||
|
||||
s.rw.Lock()
|
||||
defer s.rw.Unlock()
|
||||
if s.errors == nil {
|
||||
s.errors = err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StarShell) readState() ([]byte, []byte, error) {
|
||||
s.rw.RLock()
|
||||
defer s.rw.RUnlock()
|
||||
|
||||
outCopy := make([]byte, len(s.outbyte))
|
||||
copy(outCopy, s.outbyte)
|
||||
|
||||
errCopy := make([]byte, len(s.errbyte))
|
||||
copy(errCopy, s.errbyte)
|
||||
|
||||
if s.errors != nil {
|
||||
return outCopy, errCopy, s.errors
|
||||
}
|
||||
|
||||
return outCopy, errCopy, nil
|
||||
}
|
||||
|
||||
func stripCommandEchoFromOutput(output string, cmd string) string {
|
||||
if output == "" || cmd == "" {
|
||||
return strings.TrimSpace(output)
|
||||
}
|
||||
|
||||
lines := strings.Split(output, "\n")
|
||||
cmdLines := strings.Split(cmd, "\n")
|
||||
|
||||
for _, cmdLine := range cmdLines {
|
||||
trimmedCmd := strings.TrimSpace(cmdLine)
|
||||
if trimmedCmd == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
for i, line := range lines {
|
||||
if strings.TrimSpace(line) == trimmedCmd {
|
||||
lines = append(lines[:i], lines[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return strings.TrimSpace(strings.Join(lines, "\n"))
|
||||
}
|
||||
|
||||
func looksLikeNonPOSIXShellError(output string) bool {
|
||||
if strings.TrimSpace(output) == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
lower := strings.ToLower(output)
|
||||
indicators := []string{
|
||||
"is not recognized as an internal or external command",
|
||||
"the term ",
|
||||
"command not found",
|
||||
"unknown command",
|
||||
"not found",
|
||||
"not recognized",
|
||||
}
|
||||
for _, indicator := range indicators {
|
||||
if strings.Contains(lower, indicator) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *StarShell) GetUid() string {
|
||||
res, _, _ := s.ShellWait("id -u")
|
||||
return strings.TrimSpace(res)
|
||||
}
|
||||
|
||||
func (s *StarShell) GetGid() string {
|
||||
res, _, _ := s.ShellWait("id -g")
|
||||
return strings.TrimSpace(res)
|
||||
}
|
||||
|
||||
func (s *StarShell) GetUser() string {
|
||||
res, _, _ := s.ShellWait("id -un")
|
||||
return strings.TrimSpace(res)
|
||||
}
|
||||
|
||||
func (s *StarShell) GetGroup() string {
|
||||
res, _, _ := s.ShellWait("id -gn")
|
||||
return strings.TrimSpace(res)
|
||||
}
|
||||
636
ssh.go
636
ssh.go
@ -1,636 +0,0 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type StarSSH struct {
|
||||
Client *ssh.Client
|
||||
PublicKey ssh.PublicKey
|
||||
PubkeyBase64 string
|
||||
Hostname string
|
||||
RemoteAddr net.Addr
|
||||
Banner string
|
||||
LoginInfo LoginInput
|
||||
online bool
|
||||
}
|
||||
|
||||
type LoginInput struct {
|
||||
KeyExchanges []string
|
||||
Ciphers []string
|
||||
MACs []string
|
||||
User string
|
||||
Password string
|
||||
Prikey string
|
||||
PrikeyPwd string
|
||||
Addr string
|
||||
Port int
|
||||
Timeout time.Duration
|
||||
HostKeyCallback func(string, net.Addr, ssh.PublicKey) error
|
||||
BannerCallback func(string) error
|
||||
}
|
||||
type StarShell struct {
|
||||
Keyword string
|
||||
UseWaitDefault bool
|
||||
Session *ssh.Session
|
||||
in io.Writer
|
||||
out *bufio.Reader
|
||||
er *bufio.Reader
|
||||
outbyte []byte
|
||||
errbyte []byte
|
||||
lastout int64
|
||||
errors error
|
||||
isprint bool
|
||||
isfuncs bool
|
||||
iscolor bool
|
||||
isecho bool
|
||||
rw sync.RWMutex
|
||||
funcs func(string)
|
||||
}
|
||||
|
||||
func Login(info LoginInput) (*StarSSH, error) {
|
||||
var (
|
||||
auth []ssh.AuthMethod
|
||||
clientConfig *ssh.ClientConfig
|
||||
config ssh.Config
|
||||
err error
|
||||
)
|
||||
sshInfo := new(StarSSH)
|
||||
// get auth method
|
||||
auth = make([]ssh.AuthMethod, 0)
|
||||
if info.Prikey == "" {
|
||||
keyboardInteractiveChallenge := func(
|
||||
user,
|
||||
instruction string,
|
||||
questions []string,
|
||||
echos []bool,
|
||||
) (answers []string, err error) {
|
||||
if len(questions) == 0 {
|
||||
return []string{}, nil
|
||||
}
|
||||
return []string{info.Password}, nil
|
||||
}
|
||||
auth = append(auth, ssh.Password(info.Password))
|
||||
auth = append(auth, ssh.KeyboardInteractive(keyboardInteractiveChallenge))
|
||||
} else {
|
||||
pemBytes := []byte(info.Prikey)
|
||||
var signer ssh.Signer
|
||||
if info.PrikeyPwd == "" {
|
||||
signer, err = ssh.ParsePrivateKey(pemBytes)
|
||||
} else {
|
||||
signer, err = ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(info.PrikeyPwd))
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
auth = append(auth, ssh.PublicKeys(signer))
|
||||
}
|
||||
|
||||
if len(info.Ciphers) == 0 {
|
||||
config = ssh.Config{
|
||||
Ciphers: []string{"aes128-ctr", "aes192-ctr", "aes256-ctr", "aes128-gcm@openssh.com", "arcfour256", "arcfour128", "aes128-cbc", "3des-cbc", "aes192-cbc", "aes256-cbc", "chacha20-poly1305@openssh.com"},
|
||||
}
|
||||
} else {
|
||||
config = ssh.Config{
|
||||
Ciphers: info.Ciphers,
|
||||
}
|
||||
}
|
||||
|
||||
if len(info.MACs) != 0 {
|
||||
config.MACs = info.MACs
|
||||
}
|
||||
if len(info.KeyExchanges) != 0 {
|
||||
config.KeyExchanges = info.KeyExchanges
|
||||
}
|
||||
|
||||
if info.Timeout == 0 {
|
||||
info.Timeout = time.Second * 5
|
||||
}
|
||||
hostKeycbfunc := func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
|
||||
sshInfo.PublicKey = key
|
||||
sshInfo.RemoteAddr = remote
|
||||
sshInfo.Hostname = hostname
|
||||
if info.HostKeyCallback != nil {
|
||||
return info.HostKeyCallback(hostname, remote, key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
bannercbfunc := func(banner string) error {
|
||||
sshInfo.Banner = banner
|
||||
if info.BannerCallback != nil {
|
||||
return info.BannerCallback(banner)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
clientConfig = &ssh.ClientConfig{
|
||||
User: info.User,
|
||||
Auth: auth,
|
||||
Timeout: info.Timeout,
|
||||
Config: config,
|
||||
HostKeyCallback: hostKeycbfunc,
|
||||
BannerCallback: bannercbfunc,
|
||||
}
|
||||
|
||||
// connet to ssh
|
||||
|
||||
sshInfo.LoginInfo = info
|
||||
sshInfo.Client, err = ssh.Dial("tcp", fmt.Sprintf("%s:%d", info.Addr, info.Port), clientConfig)
|
||||
if err == nil && sshInfo.PublicKey != nil {
|
||||
sshInfo.online = true
|
||||
sshInfo.PubkeyBase64 = base64.StdEncoding.EncodeToString(sshInfo.PublicKey.Marshal())
|
||||
}
|
||||
return sshInfo, err
|
||||
}
|
||||
|
||||
func LoginSimple(host string, user string, passwd string, prikeyPath string, port int, timeout time.Duration) (*StarSSH, error) {
|
||||
var info = LoginInput{
|
||||
Addr: host,
|
||||
Port: port,
|
||||
Timeout: timeout,
|
||||
User: user,
|
||||
}
|
||||
if prikeyPath != "" {
|
||||
prikey, err := ioutil.ReadFile(prikeyPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info.Prikey = string(prikey)
|
||||
if passwd != "" {
|
||||
info.PrikeyPwd = passwd
|
||||
}
|
||||
} else {
|
||||
info.Password = passwd
|
||||
}
|
||||
return Login(info)
|
||||
}
|
||||
|
||||
func (s *StarShell) ShellWait(cmd string) (string, string, error) {
|
||||
var outc, errc string = " ", " "
|
||||
s.Clear()
|
||||
defer s.Clear()
|
||||
echo := "echo b7Y85R56TUY6R5UTb612"
|
||||
err := s.WriteCommand(cmd)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
err = s.WriteCommand(echo)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
for {
|
||||
time.Sleep(time.Millisecond * 120)
|
||||
outs := string(s.outbyte)
|
||||
errs := string(s.errbyte)
|
||||
outs = strings.TrimSpace(strings.ReplaceAll(outs, "\r\n", "\n"))
|
||||
errs = strings.TrimSpace(strings.ReplaceAll(errs, "\r\n", "\n"))
|
||||
if len(outs) >= len(cmd+"\n"+echo) && outs[0:len(cmd+"\n"+echo)] == cmd+"\n"+echo {
|
||||
outs = outs[len(cmd+"\n"+echo):]
|
||||
} else if len(outs) >= len(cmd) && outs[0:len(cmd)] == cmd {
|
||||
outs = outs[len(cmd):]
|
||||
}
|
||||
if len(errs) >= len(cmd) && errs[0:len(cmd)] == cmd {
|
||||
errs = errs[len(cmd):]
|
||||
}
|
||||
if s.UseWaitDefault {
|
||||
if strings.Index(string(outs), "b7Y85R56TUY6R5UTb612") >= 0 {
|
||||
list := strings.Split(string(outs), "\n")
|
||||
for _, v := range list {
|
||||
if strings.Index(v, "b7Y85R56TUY6R5UTb612") < 0 {
|
||||
outc += v + "\n"
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
if strings.Index(string(errs), "b7Y85R56TUY6R5UTb612") >= 0 {
|
||||
list := strings.Split(string(errs), "\n")
|
||||
for _, v := range list {
|
||||
if strings.Index(v, "b7Y85R56TUY6R5UTb612") < 0 {
|
||||
errc += v + "\n"
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if s.Keyword != "" {
|
||||
if strings.Index(string(outs), s.Keyword) >= 0 {
|
||||
list := strings.Split(string(outs), "\n")
|
||||
for _, v := range list {
|
||||
if strings.Index(v, s.Keyword) < 0 && strings.Index(v, "b7Y85R56TUY6R5UTb612") < 0 {
|
||||
outc += v + "\n"
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
if strings.Index(string(errs), s.Keyword) >= 0 {
|
||||
list := strings.Split(string(errs), "\n")
|
||||
for _, v := range list {
|
||||
if strings.Index(v, s.Keyword) < 0 && strings.Index(v, "b7Y85R56TUY6R5UTb612") < 0 {
|
||||
errc += v + "\n"
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return s.TrimColor(strings.TrimSpace(outc)), s.TrimColor(strings.TrimSpace(errc)), err
|
||||
}
|
||||
|
||||
func (s *StarShell) Close() error {
|
||||
return s.Session.Close()
|
||||
}
|
||||
|
||||
func (s *StarShell) SwitchNoColor(is bool) {
|
||||
s.iscolor = is
|
||||
}
|
||||
|
||||
func (s *StarShell) SwitchEcho(is bool) {
|
||||
s.isecho = is
|
||||
}
|
||||
|
||||
func (s *StarShell) TrimColor(str string) string {
|
||||
if s.iscolor {
|
||||
return SedColor(str)
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
/*
|
||||
本函数控制是否在本地屏幕上打印远程Shell的输出内容[true|false]
|
||||
*/
|
||||
func (s *StarShell) SwitchPrint(run bool) {
|
||||
s.isprint = run
|
||||
}
|
||||
|
||||
/*
|
||||
本函数控制是否立即处理远程Shell输出每一行内容[true|false]
|
||||
*/
|
||||
func (s *StarShell) SwitchFunc(run bool) {
|
||||
s.isfuncs = run
|
||||
}
|
||||
|
||||
func (s *StarShell) SetFunc(funcs func(string)) {
|
||||
s.funcs = funcs
|
||||
}
|
||||
|
||||
func (s *StarShell) Clear() {
|
||||
defer s.rw.Unlock()
|
||||
s.rw.Lock()
|
||||
s.outbyte = []byte{}
|
||||
s.errbyte = []byte{}
|
||||
time.Sleep(time.Millisecond * 15)
|
||||
}
|
||||
|
||||
func (s *StarShell) ShellClear(cmd string, sleep int) (string, string, error) {
|
||||
defer s.Clear()
|
||||
s.Clear()
|
||||
return s.Shell(cmd, sleep)
|
||||
}
|
||||
|
||||
func (s *StarShell) Shell(cmd string, sleep int) (string, string, error) {
|
||||
if err := s.WriteCommand(cmd); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
tmp1, tmp2, err := s.GetResult(sleep)
|
||||
tmps := s.TrimColor(strings.TrimSpace(string(tmp1)))
|
||||
if s.isecho {
|
||||
n := len(strings.Split(cmd, "\n"))
|
||||
if n == 1 {
|
||||
list := strings.SplitN(tmps, "\n", 2)
|
||||
if len(list) == 2 {
|
||||
tmps = list[1]
|
||||
}
|
||||
} else {
|
||||
list := strings.Split(tmps, "\n")
|
||||
cmds := strings.Split(cmd, "\n")
|
||||
for _, v := range cmds {
|
||||
for k, v2 := range list {
|
||||
if strings.TrimSpace(v2) == strings.TrimSpace(v) {
|
||||
list[k] = ""
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
tmps = ""
|
||||
for _, v := range list {
|
||||
if v != "" {
|
||||
tmps += v + "\n"
|
||||
}
|
||||
}
|
||||
tmps = tmps[0 : len(tmps)-1]
|
||||
}
|
||||
}
|
||||
return tmps, s.TrimColor(strings.TrimSpace(string(tmp2))), err
|
||||
}
|
||||
|
||||
func (s *StarShell) GetResult(sleep int) ([]byte, []byte, error) {
|
||||
if s.errors != nil {
|
||||
s.Session.Close()
|
||||
return s.outbyte, s.errbyte, s.errors
|
||||
}
|
||||
if sleep > 0 {
|
||||
time.Sleep(time.Millisecond * time.Duration(sleep))
|
||||
}
|
||||
return s.outbyte, s.errbyte, nil
|
||||
}
|
||||
|
||||
func (s *StarShell) WriteCommand(cmd string) error {
|
||||
return s.Write([]byte(cmd + "\n"))
|
||||
}
|
||||
|
||||
func (s *StarShell) Write(bstr []byte) error {
|
||||
if s.errors != nil {
|
||||
s.Session.Close()
|
||||
return s.errors
|
||||
}
|
||||
_, err := s.in.Write(bstr)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *StarShell) gohub() {
|
||||
go func() {
|
||||
var cache []byte
|
||||
for {
|
||||
read, err := s.er.ReadByte()
|
||||
if err != nil {
|
||||
s.errors = err
|
||||
return
|
||||
}
|
||||
s.errbyte = append(s.errbyte, read)
|
||||
if s.isprint {
|
||||
fmt.Print(string([]byte{read}))
|
||||
}
|
||||
cache = append(cache, read)
|
||||
if read == '\n' {
|
||||
if s.isfuncs {
|
||||
go s.funcs(s.TrimColor(strings.TrimSpace(string(cache))))
|
||||
cache = []byte{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
var cache []byte
|
||||
for {
|
||||
read, err := s.out.ReadByte()
|
||||
if err != nil {
|
||||
s.errors = err
|
||||
return
|
||||
}
|
||||
s.rw.Lock()
|
||||
s.outbyte = append(s.outbyte, read)
|
||||
cache = append(cache, read)
|
||||
s.rw.Unlock()
|
||||
if read == '\n' {
|
||||
if s.isfuncs {
|
||||
go s.funcs(strings.TrimSpace(string(cache)))
|
||||
cache = []byte{}
|
||||
}
|
||||
}
|
||||
if s.isprint {
|
||||
fmt.Print(string([]byte{read}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StarShell) GetUid() string {
|
||||
res, _, _ := s.ShellWait(`id | grep -oP "(?<=uid\=)\d+"`)
|
||||
return strings.TrimSpace(res)
|
||||
}
|
||||
func (s *StarShell) GetGid() string {
|
||||
res, _, _ := s.ShellWait(`id | grep -oP "(?<=gid\=)\d+"`)
|
||||
return strings.TrimSpace(res)
|
||||
}
|
||||
func (s *StarShell) GetUser() string {
|
||||
res, _, _ := s.ShellWait(`id | grep -oP "(?<=\().*?(?=\))" | head -n 1`)
|
||||
return strings.TrimSpace(res)
|
||||
}
|
||||
func (s *StarShell) GetGroup() string {
|
||||
res, _, _ := s.ShellWait(`id | grep -oP "(?<=\().*?(?=\))" | head -n 2 | tail -n 1`)
|
||||
return strings.TrimSpace(res)
|
||||
}
|
||||
|
||||
func (s *StarSSH) NewShell() (shell *StarShell, err error) {
|
||||
shell = new(StarShell)
|
||||
shell.Session, err = s.NewSession()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
shell.in, _ = shell.Session.StdinPipe()
|
||||
tmp, _ := shell.Session.StdoutPipe()
|
||||
shell.out = bufio.NewReader(tmp)
|
||||
tmp, _ = shell.Session.StderrPipe()
|
||||
shell.er = bufio.NewReader(tmp)
|
||||
err = shell.Session.Shell()
|
||||
shell.isecho = true
|
||||
go shell.Session.Wait()
|
||||
shell.UseWaitDefault = true
|
||||
shell.WriteCommand("bash")
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
shell.WriteCommand("export PS1= ")
|
||||
shell.WriteCommand("export PS2= ")
|
||||
go shell.gohub()
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
shell.Clear()
|
||||
shell.Clear()
|
||||
return
|
||||
}
|
||||
|
||||
func (s *StarSSH) Close() error {
|
||||
if s.online {
|
||||
return s.Client.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *StarSSH) NewSession() (*ssh.Session, error) {
|
||||
return NewSession(s.Client)
|
||||
}
|
||||
|
||||
func (s *StarSSH) ShellOne(cmd string) (string, error) {
|
||||
newsess, err := s.NewSession()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
data, err := newsess.CombinedOutput(cmd)
|
||||
newsess.Close()
|
||||
return strings.TrimSpace(string(data)), err
|
||||
}
|
||||
|
||||
func (s *StarSSH) Exists(filepath string) bool {
|
||||
res, _ := s.ShellOne(`echo 1 && [ ! -e "` + filepath + `" ] && echo 2`)
|
||||
if res == "1" {
|
||||
return true
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StarSSH) GetUid() string {
|
||||
res, _ := s.ShellOne(`id | grep -oP "(?<=uid\=)\d+"`)
|
||||
return strings.TrimSpace(res)
|
||||
}
|
||||
func (s *StarSSH) GetGid() string {
|
||||
res, _ := s.ShellOne(`id | grep -oP "(?<=gid\=)\d+"`)
|
||||
return strings.TrimSpace(res)
|
||||
}
|
||||
func (s *StarSSH) GetUser() string {
|
||||
res, _ := s.ShellOne(`id | grep -oP "(?<=\().*?(?=\))" | head -n 1`)
|
||||
return strings.TrimSpace(res)
|
||||
}
|
||||
func (s *StarSSH) GetGroup() string {
|
||||
res, _ := s.ShellOne(`id | grep -oP "(?<=\().*?(?=\))" | head -n 2 | tail -n 1`)
|
||||
return strings.TrimSpace(res)
|
||||
}
|
||||
|
||||
func (s *StarSSH) IsFile(filepath string) bool {
|
||||
res, _ := s.ShellOne(`echo 1 && [ ! -f "` + filepath + `" ] && echo 2`)
|
||||
if res == "1" {
|
||||
return true
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StarSSH) IsFolder(filepath string) bool {
|
||||
res, _ := s.ShellOne(`echo 1 && [ ! -d "` + filepath + `" ] && echo 2`)
|
||||
if res == "1" {
|
||||
return true
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StarSSH) ShellOneShowScreen(cmd string) (string, error) {
|
||||
newsess, err := s.NewSession()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var bytes, errbytes []byte
|
||||
tmp, _ := newsess.StdoutPipe()
|
||||
reader := bufio.NewReader(tmp)
|
||||
tmp, _ = newsess.StderrPipe()
|
||||
errder := bufio.NewReader(tmp)
|
||||
err = newsess.Start(cmd)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
c := make(chan int, 1)
|
||||
go newsess.Wait()
|
||||
go func() {
|
||||
for {
|
||||
byt, err := reader.ReadByte()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
fmt.Print(string([]byte{byt}))
|
||||
bytes = append(bytes, byt)
|
||||
}
|
||||
c <- 1
|
||||
}()
|
||||
for {
|
||||
byt, err := errder.ReadByte()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
fmt.Print(string([]byte{byt}))
|
||||
errbytes = append(errbytes, byt)
|
||||
}
|
||||
_ = <-c
|
||||
newsess.Close()
|
||||
if len(errbytes) != 0 {
|
||||
err = errors.New(strings.TrimSpace(string(errbytes)))
|
||||
} else {
|
||||
err = nil
|
||||
}
|
||||
return strings.TrimSpace(string(bytes)), err
|
||||
}
|
||||
|
||||
func (s *StarSSH) ShellOneToFunc(cmd string, callback func(string)) (string, error) {
|
||||
newsess, err := s.NewSession()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var bytes, errbytes []byte
|
||||
tmp, _ := newsess.StdoutPipe()
|
||||
reader := bufio.NewReader(tmp)
|
||||
tmp, _ = newsess.StderrPipe()
|
||||
errder := bufio.NewReader(tmp)
|
||||
err = newsess.Start(cmd)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
c := make(chan int, 1)
|
||||
go newsess.Wait()
|
||||
go func() {
|
||||
for {
|
||||
byt, err := reader.ReadByte()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
callback(string([]byte{byt}))
|
||||
bytes = append(bytes, byt)
|
||||
}
|
||||
c <- 1
|
||||
}()
|
||||
for {
|
||||
byt, err := errder.ReadByte()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
callback(string([]byte{byt}))
|
||||
errbytes = append(errbytes, byt)
|
||||
}
|
||||
_ = <-c
|
||||
newsess.Close()
|
||||
if len(errbytes) != 0 {
|
||||
err = errors.New(strings.TrimSpace(string(errbytes)))
|
||||
} else {
|
||||
err = nil
|
||||
}
|
||||
return strings.TrimSpace(string(bytes)), err
|
||||
}
|
||||
|
||||
func NewTransferSession(client *ssh.Client) (*ssh.Session, error) {
|
||||
session, err := client.NewSession()
|
||||
return session, err
|
||||
}
|
||||
|
||||
func NewSession(client *ssh.Client) (*ssh.Session, error) {
|
||||
var session *ssh.Session
|
||||
var err error
|
||||
// create session
|
||||
if session, err = client.NewSession(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
modes := ssh.TerminalModes{
|
||||
ssh.ECHO: 1, // 还是要强制开启
|
||||
//ssh.IGNCR: 0,
|
||||
ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
|
||||
ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
|
||||
}
|
||||
|
||||
if err := session.RequestPty("xterm", 500, 250, modes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func SedColor(str string) string {
|
||||
reg := regexp.MustCompile(`\x1B\[([0-9]{1,2}(;[0-9]{1,2})?)?[m|K]`)
|
||||
//fmt.Println("regexp:", reg.Match([]byte(str)))
|
||||
return string(reg.ReplaceAll([]byte(str), []byte("")))
|
||||
}
|
||||
21
sshagent_unix.go
Normal file
21
sshagent_unix.go
Normal file
@ -0,0 +1,21 @@
|
||||
//go:build !windows
|
||||
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func dialSSHAgent(timeout time.Duration) (net.Conn, error) {
|
||||
agentSock := strings.TrimSpace(os.Getenv("SSH_AUTH_SOCK"))
|
||||
if agentSock == "" {
|
||||
return nil, errSSHAgentUnavailable
|
||||
}
|
||||
if timeout > 0 {
|
||||
return net.DialTimeout("unix", agentSock, timeout)
|
||||
}
|
||||
return net.Dial("unix", agentSock)
|
||||
}
|
||||
70
sshagent_windows.go
Normal file
70
sshagent_windows.go
Normal file
@ -0,0 +1,70 @@
|
||||
//go:build windows
|
||||
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Microsoft/go-winio"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
const defaultWindowsSSHAgentPipe = `\\.\pipe\openssh-ssh-agent`
|
||||
|
||||
func dialSSHAgent(timeout time.Duration) (net.Conn, error) {
|
||||
agentSock := strings.TrimSpace(os.Getenv("SSH_AUTH_SOCK"))
|
||||
if agentSock != "" {
|
||||
return dialWindowsSSHAgentEndpoint(agentSock, timeout)
|
||||
}
|
||||
return dialWindowsNamedPipe(defaultWindowsSSHAgentPipe, timeout, true)
|
||||
}
|
||||
|
||||
func dialWindowsSSHAgentEndpoint(endpoint string, timeout time.Duration) (net.Conn, error) {
|
||||
if pipePath, ok := normalizeWindowsSSHAgentPipe(endpoint); ok {
|
||||
return dialWindowsNamedPipe(pipePath, timeout, false)
|
||||
}
|
||||
if timeout > 0 {
|
||||
return net.DialTimeout("unix", endpoint, timeout)
|
||||
}
|
||||
return net.Dial("unix", endpoint)
|
||||
}
|
||||
|
||||
func dialWindowsNamedPipe(path string, timeout time.Duration, unavailableOnNotFound bool) (net.Conn, error) {
|
||||
ctx := context.Background()
|
||||
cancel := func() {}
|
||||
if timeout > 0 {
|
||||
ctx, cancel = context.WithTimeout(ctx, timeout)
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
conn, err := winio.DialPipeContext(ctx, path)
|
||||
if err != nil && unavailableOnNotFound && isWindowsPipeUnavailable(err) {
|
||||
return nil, errSSHAgentUnavailable
|
||||
}
|
||||
return conn, err
|
||||
}
|
||||
|
||||
func normalizeWindowsSSHAgentPipe(endpoint string) (string, bool) {
|
||||
trimmed := strings.TrimSpace(endpoint)
|
||||
if trimmed == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
normalized := trimmed
|
||||
if strings.HasPrefix(normalized, "//./pipe/") {
|
||||
normalized = `\\.\pipe\` + strings.TrimPrefix(normalized, "//./pipe/")
|
||||
}
|
||||
if strings.HasPrefix(normalized, `\\.\pipe\`) {
|
||||
return normalized, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func isWindowsPipeUnavailable(err error) bool {
|
||||
return errors.Is(err, windows.ERROR_FILE_NOT_FOUND) || errors.Is(err, windows.ERROR_PATH_NOT_FOUND)
|
||||
}
|
||||
106
state.go
Normal file
106
state.go
Normal file
@ -0,0 +1,106 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type sshClientRequester interface {
|
||||
SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
var closeSSHClient = func(client sshClientRequester) error {
|
||||
if client == nil {
|
||||
return nil
|
||||
}
|
||||
return client.Close()
|
||||
}
|
||||
|
||||
func (s *StarSSH) snapshotSSHClient() *ssh.Client {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.stateMu.RLock()
|
||||
defer s.stateMu.RUnlock()
|
||||
return s.Client
|
||||
}
|
||||
|
||||
func (s *StarSSH) requireSSHClient() (*ssh.Client, error) {
|
||||
client := s.snapshotSSHClient()
|
||||
if client == nil {
|
||||
return nil, errors.New("ssh client is nil")
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (s *StarSSH) setTransport(client *ssh.Client, upstream *StarSSH) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.stateMu.Lock()
|
||||
defer s.stateMu.Unlock()
|
||||
s.Client = client
|
||||
s.upstream = upstream
|
||||
s.online = client != nil
|
||||
}
|
||||
|
||||
func (s *StarSSH) detachTransport() (*ssh.Client, *StarSSH) {
|
||||
if s == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
s.stateMu.Lock()
|
||||
defer s.stateMu.Unlock()
|
||||
|
||||
client := s.Client
|
||||
upstream := s.upstream
|
||||
s.Client = nil
|
||||
s.upstream = nil
|
||||
s.online = false
|
||||
return client, upstream
|
||||
}
|
||||
|
||||
func (s *StarSSH) takeKeepaliveHandles() (chan struct{}, chan struct{}) {
|
||||
if s == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
s.keepaliveMu.Lock()
|
||||
defer s.keepaliveMu.Unlock()
|
||||
|
||||
stop := s.keepaliveStop
|
||||
done := s.keepaliveDone
|
||||
s.keepaliveStop = nil
|
||||
s.keepaliveDone = nil
|
||||
return stop, done
|
||||
}
|
||||
|
||||
func (s *StarSSH) closeTransport(waitKeepalive bool) error {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
_ = s.closeReusableSFTPClient()
|
||||
|
||||
client, upstream := s.detachTransport()
|
||||
stop, done := s.takeKeepaliveHandles()
|
||||
if stop != nil {
|
||||
close(stop)
|
||||
}
|
||||
|
||||
var closeErr error
|
||||
if client != nil {
|
||||
closeErr = normalizeAlreadyClosedError(closeSSHClient(client))
|
||||
}
|
||||
if waitKeepalive && done != nil {
|
||||
<-done
|
||||
}
|
||||
if upstreamErr := closeUpstream(upstream); closeErr == nil {
|
||||
closeErr = upstreamErr
|
||||
}
|
||||
return closeErr
|
||||
}
|
||||
431
terminal.go
Normal file
431
terminal.go
Normal file
@ -0,0 +1,431 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func (s *StarSSH) NewTerminal(config *TerminalConfig) (*TerminalSession, error) {
|
||||
session, err := s.NewPTYSession(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stdin, err := session.StdinPipe()
|
||||
if err != nil {
|
||||
_ = session.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stdout, err := session.StdoutPipe()
|
||||
if err != nil {
|
||||
_ = session.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stderr, err := session.StderrPipe()
|
||||
if err != nil {
|
||||
_ = session.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := session.Shell(); err != nil {
|
||||
_ = session.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TerminalSession{
|
||||
Session: session,
|
||||
stdin: stdin,
|
||||
stdout: stdout,
|
||||
stderr: stderr,
|
||||
runDone: make(chan struct{}),
|
||||
waitDone: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *TerminalSession) AttachIO(stdin io.Reader, stdout io.Writer, stderr io.Writer) {
|
||||
if t == nil {
|
||||
return
|
||||
}
|
||||
|
||||
t.attachMu.Lock()
|
||||
defer t.attachMu.Unlock()
|
||||
t.in = stdin
|
||||
t.out = stdout
|
||||
t.errOut = stderr
|
||||
}
|
||||
|
||||
func (t *TerminalSession) StdinWriter() io.Writer {
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
return t.stdin
|
||||
}
|
||||
|
||||
func (t *TerminalSession) StdoutReader() io.Reader {
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
return t.stdout
|
||||
}
|
||||
|
||||
func (t *TerminalSession) StderrReader() io.Reader {
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
return t.stderr
|
||||
}
|
||||
|
||||
func (t *TerminalSession) Write(data []byte) (int, error) {
|
||||
if t == nil || t.stdin == nil {
|
||||
return 0, errors.New("terminal stdin is not initialized")
|
||||
}
|
||||
return t.stdin.Write(data)
|
||||
}
|
||||
|
||||
func (t *TerminalSession) SendControl(control TerminalControl) error {
|
||||
_, err := t.Write([]byte{byte(control)})
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *TerminalSession) Interrupt() error {
|
||||
return t.SendControl(TerminalControlInterrupt)
|
||||
}
|
||||
|
||||
func (t *TerminalSession) Signal(sig ssh.Signal) error {
|
||||
if t == nil || t.Session == nil {
|
||||
return errors.New("terminal session is not initialized")
|
||||
}
|
||||
if sig == "" {
|
||||
return errors.New("signal is empty")
|
||||
}
|
||||
return t.Session.Signal(sig)
|
||||
}
|
||||
|
||||
func (t *TerminalSession) Run(ctx context.Context) error {
|
||||
if t == nil {
|
||||
return errors.New("terminal session is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
t.runOnce.Do(func() {
|
||||
t.runErr = t.run(ctx)
|
||||
close(t.runDone)
|
||||
})
|
||||
|
||||
<-t.runDone
|
||||
return t.runErr
|
||||
}
|
||||
|
||||
func (t *TerminalSession) run(ctx context.Context) error {
|
||||
if t.Session == nil {
|
||||
return errors.New("terminal session is not initialized")
|
||||
}
|
||||
|
||||
t.attachMu.RLock()
|
||||
in := t.in
|
||||
out := t.out
|
||||
errOut := t.errOut
|
||||
t.attachMu.RUnlock()
|
||||
if out == nil {
|
||||
out = io.Discard
|
||||
}
|
||||
if errOut == nil {
|
||||
errOut = out
|
||||
}
|
||||
|
||||
inputReader, cancelInput, inputCancelable, err := prepareTerminalInputReader(in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancelInput()
|
||||
|
||||
var copyWG sync.WaitGroup
|
||||
doneCopy := make(chan struct{})
|
||||
copyWG.Add(2)
|
||||
go func() {
|
||||
defer copyWG.Done()
|
||||
if t.stdout != nil {
|
||||
_, _ = io.Copy(out, t.stdout)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer copyWG.Done()
|
||||
if t.stderr != nil {
|
||||
_, _ = io.Copy(errOut, t.stderr)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
copyWG.Wait()
|
||||
close(doneCopy)
|
||||
}()
|
||||
|
||||
var doneInput chan struct{}
|
||||
if inputReader != nil && t.stdin != nil {
|
||||
doneInput = make(chan struct{})
|
||||
go func() {
|
||||
defer close(doneInput)
|
||||
_, _ = io.Copy(t.stdin, inputReader)
|
||||
_ = t.stdin.Close()
|
||||
}()
|
||||
}
|
||||
|
||||
waitInputPump := func() {
|
||||
if doneInput == nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-doneInput:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if inputCancelable {
|
||||
cancelInput()
|
||||
<-doneInput
|
||||
}
|
||||
}
|
||||
|
||||
type waitResult struct {
|
||||
info TerminalExitInfo
|
||||
err error
|
||||
}
|
||||
|
||||
waitCh := make(chan waitResult, 1)
|
||||
go func() {
|
||||
info, err := t.WaitResult()
|
||||
waitCh <- waitResult{info: info, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case result := <-waitCh:
|
||||
waitInputPump()
|
||||
<-doneCopy
|
||||
if result.err != nil {
|
||||
return result.err
|
||||
}
|
||||
return result.info.CommandError()
|
||||
case <-ctx.Done():
|
||||
t.markCloseReason(terminalCloseReasonFromErr(ctx.Err()), ctx.Err())
|
||||
cancelInput()
|
||||
_ = t.Close()
|
||||
<-waitCh
|
||||
waitInputPump()
|
||||
<-doneCopy
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TerminalSession) Wait() error {
|
||||
info, err := t.WaitResult()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return info.CommandError()
|
||||
}
|
||||
|
||||
func (t *TerminalSession) WaitResult() (TerminalExitInfo, error) {
|
||||
waitErr := t.waitRaw()
|
||||
info, closeErr := t.snapshotExitState()
|
||||
if closeErr != nil {
|
||||
return info, closeErr
|
||||
}
|
||||
if waitErr == nil {
|
||||
return info, nil
|
||||
}
|
||||
|
||||
var exitErr *ssh.ExitError
|
||||
if errors.As(waitErr, &exitErr) {
|
||||
return info, nil
|
||||
}
|
||||
if normalizeAlreadyClosedError(waitErr) == nil || info.Reason == TerminalCloseReasonClosed {
|
||||
return info, nil
|
||||
}
|
||||
return info, waitErr
|
||||
}
|
||||
|
||||
func (t *TerminalSession) ExitInfo() TerminalExitInfo {
|
||||
if t == nil {
|
||||
return TerminalExitInfo{}
|
||||
}
|
||||
|
||||
t.stateMu.RLock()
|
||||
defer t.stateMu.RUnlock()
|
||||
return t.exitInfo
|
||||
}
|
||||
|
||||
func (info TerminalExitInfo) Success() bool {
|
||||
return info.Reason == TerminalCloseReasonExit && info.ExitCode == 0 && info.ExitSignal == ""
|
||||
}
|
||||
|
||||
func (info TerminalExitInfo) CommandError() error {
|
||||
if info.Reason != TerminalCloseReasonExit && info.Reason != TerminalCloseReasonSignal {
|
||||
return nil
|
||||
}
|
||||
if info.ExitCode == 0 && info.ExitSignal == "" {
|
||||
return nil
|
||||
}
|
||||
return &ExecExitError{
|
||||
Status: info.ExitCode,
|
||||
Signal: info.ExitSignal,
|
||||
Message: info.ExitMessage,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TerminalSession) Resize(columns int, rows int) error {
|
||||
if t == nil || t.Session == nil {
|
||||
return errors.New("terminal session is not initialized")
|
||||
}
|
||||
if columns <= 0 || rows <= 0 {
|
||||
return errors.New("columns and rows must be > 0")
|
||||
}
|
||||
return t.Session.WindowChange(rows, columns)
|
||||
}
|
||||
|
||||
func (t *TerminalSession) Close() error {
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var closeErr error
|
||||
t.closeOnce.Do(func() {
|
||||
if t.stdin != nil {
|
||||
_ = t.stdin.Close()
|
||||
}
|
||||
if t.Session != nil {
|
||||
closeErr = normalizeAlreadyClosedError(t.Session.Close())
|
||||
}
|
||||
})
|
||||
if closeErr != nil {
|
||||
t.markCloseReason(TerminalCloseReasonTransportError, closeErr)
|
||||
}
|
||||
return closeErr
|
||||
}
|
||||
|
||||
func (t *TerminalSession) waitRaw() error {
|
||||
if t == nil || t.Session == nil {
|
||||
return errors.New("terminal session is not initialized")
|
||||
}
|
||||
|
||||
t.waitOnce.Do(func() {
|
||||
go func() {
|
||||
waitErr := t.Session.Wait()
|
||||
t.setWaitResult(waitErr)
|
||||
close(t.waitDone)
|
||||
}()
|
||||
})
|
||||
|
||||
<-t.waitDone
|
||||
|
||||
t.stateMu.RLock()
|
||||
defer t.stateMu.RUnlock()
|
||||
return t.waitErr
|
||||
}
|
||||
|
||||
func (t *TerminalSession) setWaitResult(waitErr error) {
|
||||
if t == nil {
|
||||
return
|
||||
}
|
||||
|
||||
t.stateMu.Lock()
|
||||
defer t.stateMu.Unlock()
|
||||
t.waitErr = waitErr
|
||||
t.exitInfo = buildTerminalExitInfo(waitErr, t.closeReason)
|
||||
}
|
||||
|
||||
func (t *TerminalSession) markCloseReason(reason TerminalCloseReason, err error) {
|
||||
if t == nil || reason == TerminalCloseReasonUnknown {
|
||||
return
|
||||
}
|
||||
|
||||
t.stateMu.Lock()
|
||||
defer t.stateMu.Unlock()
|
||||
|
||||
if terminalCloseReasonPriority(reason) >= terminalCloseReasonPriority(t.closeReason) {
|
||||
t.closeReason = reason
|
||||
}
|
||||
if err != nil && t.closeErr == nil {
|
||||
t.closeErr = err
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TerminalSession) snapshotExitState() (TerminalExitInfo, error) {
|
||||
if t == nil {
|
||||
return TerminalExitInfo{}, nil
|
||||
}
|
||||
|
||||
t.stateMu.RLock()
|
||||
defer t.stateMu.RUnlock()
|
||||
return t.exitInfo, t.closeErr
|
||||
}
|
||||
|
||||
func buildTerminalExitInfo(waitErr error, overrideReason TerminalCloseReason) TerminalExitInfo {
|
||||
info := TerminalExitInfo{}
|
||||
|
||||
if waitErr == nil {
|
||||
info.Reason = TerminalCloseReasonExit
|
||||
} else {
|
||||
var exitErr *ssh.ExitError
|
||||
switch {
|
||||
case errors.As(waitErr, &exitErr):
|
||||
info.ExitCode = exitErr.ExitStatus()
|
||||
info.ExitSignal = exitErr.Signal()
|
||||
info.ExitMessage = exitErr.Msg()
|
||||
if info.ExitSignal != "" {
|
||||
info.Reason = TerminalCloseReasonSignal
|
||||
} else {
|
||||
info.Reason = TerminalCloseReasonExit
|
||||
}
|
||||
case normalizeAlreadyClosedError(waitErr) == nil:
|
||||
info.Reason = TerminalCloseReasonClosed
|
||||
case errors.Is(waitErr, context.Canceled):
|
||||
info.Reason = TerminalCloseReasonContextCanceled
|
||||
case errors.Is(waitErr, context.DeadlineExceeded):
|
||||
info.Reason = TerminalCloseReasonDeadlineExceeded
|
||||
default:
|
||||
info.Reason = TerminalCloseReasonTransportError
|
||||
}
|
||||
}
|
||||
|
||||
if overrideReason != TerminalCloseReasonUnknown &&
|
||||
terminalCloseReasonPriority(overrideReason) >= terminalCloseReasonPriority(info.Reason) {
|
||||
info.Reason = overrideReason
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
func terminalCloseReasonFromErr(err error) TerminalCloseReason {
|
||||
switch {
|
||||
case errors.Is(err, context.DeadlineExceeded):
|
||||
return TerminalCloseReasonDeadlineExceeded
|
||||
case errors.Is(err, context.Canceled):
|
||||
return TerminalCloseReasonContextCanceled
|
||||
case err != nil:
|
||||
return TerminalCloseReasonTransportError
|
||||
default:
|
||||
return TerminalCloseReasonUnknown
|
||||
}
|
||||
}
|
||||
|
||||
func terminalCloseReasonPriority(reason TerminalCloseReason) int {
|
||||
switch reason {
|
||||
case TerminalCloseReasonContextCanceled, TerminalCloseReasonDeadlineExceeded:
|
||||
return 30
|
||||
case TerminalCloseReasonTransportError:
|
||||
return 20
|
||||
case TerminalCloseReasonClosed:
|
||||
return 10
|
||||
case TerminalCloseReasonSignal, TerminalCloseReasonExit:
|
||||
return 5
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
225
terminal_input.go
Normal file
225
terminal_input.go
Normal file
@ -0,0 +1,225 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"reflect"
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// TerminalInputSourceProvider lets wrapper readers expose a closer-friendly source reader.
|
||||
// Implementations that buffer data should return a source that already includes any prefetched bytes.
|
||||
type TerminalInputSourceProvider interface {
|
||||
TerminalInputSource() io.Reader
|
||||
}
|
||||
|
||||
// TerminalInputCanceler lets wrapper readers expose an explicit cancellation hook.
|
||||
// It is useful for line editors or custom buffered readers that cannot safely expose a raw io.ReadCloser.
|
||||
type TerminalInputCanceler interface {
|
||||
TerminalInputCancel() error
|
||||
}
|
||||
|
||||
// TerminalInputAdapter adapts wrapper readers into a cancelable terminal input source.
|
||||
// Reader is what TerminalSession consumes, Source is the closer-friendly underlying reader when available.
|
||||
type TerminalInputAdapter struct {
|
||||
Reader io.Reader
|
||||
Source io.Reader
|
||||
Cancel func() error
|
||||
}
|
||||
|
||||
func (a TerminalInputAdapter) Read(p []byte) (int, error) {
|
||||
if a.Reader == nil {
|
||||
return 0, io.EOF
|
||||
}
|
||||
return a.Reader.Read(p)
|
||||
}
|
||||
|
||||
func (a TerminalInputAdapter) TerminalInputSource() io.Reader {
|
||||
if a.Source != nil {
|
||||
return a.Source
|
||||
}
|
||||
return a.Reader
|
||||
}
|
||||
|
||||
func (a TerminalInputAdapter) TerminalInputCancel() error {
|
||||
if a.Cancel != nil {
|
||||
return a.Cancel()
|
||||
}
|
||||
if closer, ok := a.Source.(io.Closer); ok && closer != nil {
|
||||
return closer.Close()
|
||||
}
|
||||
if closer, ok := a.Reader.(io.Closer); ok && closer != nil {
|
||||
return closer.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func prepareTerminalInputReader(in io.Reader) (io.Reader, func(), bool, error) {
|
||||
if in == nil {
|
||||
return nil, func() {}, false, nil
|
||||
}
|
||||
|
||||
var cancelOnce sync.Once
|
||||
wrapCancel := func(fn func()) func() {
|
||||
return func() {
|
||||
cancelOnce.Do(fn)
|
||||
}
|
||||
}
|
||||
|
||||
if provider, ok := in.(TerminalInputSourceProvider); ok {
|
||||
source := provider.TerminalInputSource()
|
||||
if source == nil || sameReader(source, in) {
|
||||
return prepareDirectTerminalInputReader(in, wrapCancel)
|
||||
}
|
||||
|
||||
prepared, cancel, cancelable, err := prepareTerminalInputReader(source)
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
if canceler, ok := in.(TerminalInputCanceler); ok {
|
||||
return prepared, wrapCancel(func() {
|
||||
cancel()
|
||||
_ = canceler.TerminalInputCancel()
|
||||
}), true, nil
|
||||
}
|
||||
return prepared, cancel, cancelable, nil
|
||||
}
|
||||
|
||||
return prepareDirectTerminalInputReader(in, wrapCancel)
|
||||
}
|
||||
|
||||
func prepareDirectTerminalInputReader(in io.Reader, wrapCancel func(func()) func()) (io.Reader, func(), bool, error) {
|
||||
if in == nil {
|
||||
return nil, func() {}, false, nil
|
||||
}
|
||||
|
||||
switch typed := in.(type) {
|
||||
case *bufio.Reader:
|
||||
return prepareBufferedTerminalInputReader(typed)
|
||||
case *bufio.ReadWriter:
|
||||
if typed.Reader == nil {
|
||||
return in, func() {}, false, nil
|
||||
}
|
||||
return prepareBufferedTerminalInputReader(typed.Reader)
|
||||
}
|
||||
|
||||
if canceler, ok := in.(TerminalInputCanceler); ok {
|
||||
return in, wrapCancel(func() {
|
||||
_ = canceler.TerminalInputCancel()
|
||||
}), true, nil
|
||||
}
|
||||
|
||||
if file, ok := in.(*os.File); ok {
|
||||
dup, err := duplicateTerminalInputFile(file)
|
||||
if err != nil {
|
||||
return nil, nil, false, fmt.Errorf("duplicate terminal input: %w", err)
|
||||
}
|
||||
return dup, wrapCancel(func() {
|
||||
_ = dup.Close()
|
||||
}), true, nil
|
||||
}
|
||||
|
||||
if closer, ok := in.(io.ReadCloser); ok {
|
||||
return closer, wrapCancel(func() {
|
||||
_ = closer.Close()
|
||||
}), true, nil
|
||||
}
|
||||
|
||||
return in, func() {}, false, nil
|
||||
}
|
||||
|
||||
func prepareBufferedTerminalInputReader(reader *bufio.Reader) (io.Reader, func(), bool, error) {
|
||||
if reader == nil {
|
||||
return nil, func() {}, false, nil
|
||||
}
|
||||
|
||||
bufferedPrefix, err := snapshotBufferedPrefix(reader)
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
|
||||
underlying := unwrapBufioReader(reader)
|
||||
if underlying == nil {
|
||||
if len(bufferedPrefix) == 0 {
|
||||
return reader, func() {}, false, nil
|
||||
}
|
||||
return io.MultiReader(bytes.NewReader(bufferedPrefix), reader), func() {}, false, nil
|
||||
}
|
||||
|
||||
prepared, cancel, cancelable, err := prepareTerminalInputReader(underlying)
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
if len(bufferedPrefix) == 0 {
|
||||
return prepared, cancel, cancelable, nil
|
||||
}
|
||||
if prepared == nil {
|
||||
return bytes.NewReader(bufferedPrefix), cancel, cancelable, nil
|
||||
}
|
||||
return io.MultiReader(bytes.NewReader(bufferedPrefix), prepared), cancel, cancelable, nil
|
||||
}
|
||||
|
||||
func snapshotBufferedPrefix(reader *bufio.Reader) ([]byte, error) {
|
||||
if reader == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
buffered := reader.Buffered()
|
||||
if buffered == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
chunk, err := reader.Peek(buffered)
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
return nil, fmt.Errorf("peek terminal input buffer: %w", err)
|
||||
}
|
||||
prefix := append([]byte(nil), chunk...)
|
||||
if _, err := reader.Discard(len(prefix)); err != nil {
|
||||
return nil, fmt.Errorf("discard terminal input buffer: %w", err)
|
||||
}
|
||||
return prefix, nil
|
||||
}
|
||||
|
||||
func unwrapBufioReader(reader *bufio.Reader) io.Reader {
|
||||
if reader == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
value := reflect.ValueOf(reader)
|
||||
if value.Kind() != reflect.Pointer || value.IsNil() {
|
||||
return nil
|
||||
}
|
||||
|
||||
field := value.Elem().FieldByName("rd")
|
||||
if !field.IsValid() {
|
||||
return nil
|
||||
}
|
||||
|
||||
underlyingValue := reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem()
|
||||
underlying, ok := underlyingValue.Interface().(io.Reader)
|
||||
if !ok || underlying == nil || sameReader(underlying, reader) {
|
||||
return nil
|
||||
}
|
||||
return underlying
|
||||
}
|
||||
|
||||
func sameReader(left io.Reader, right io.Reader) bool {
|
||||
if left == nil || right == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
leftValue := reflect.ValueOf(left)
|
||||
rightValue := reflect.ValueOf(right)
|
||||
if !leftValue.IsValid() || !rightValue.IsValid() {
|
||||
return false
|
||||
}
|
||||
if leftValue.Kind() != reflect.Pointer || rightValue.Kind() != reflect.Pointer {
|
||||
return false
|
||||
}
|
||||
return leftValue.Pointer() == rightValue.Pointer()
|
||||
}
|
||||
49
terminal_input_adapter_test.go
Normal file
49
terminal_input_adapter_test.go
Normal file
@ -0,0 +1,49 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPrepareTerminalInputReaderAdapterCancelUnblocksRead(t *testing.T) {
|
||||
reader, writer := io.Pipe()
|
||||
defer reader.Close()
|
||||
defer writer.Close()
|
||||
|
||||
adapter := TerminalInputAdapter{
|
||||
Reader: bufio.NewReader(reader),
|
||||
Source: reader,
|
||||
Cancel: func() error {
|
||||
return reader.CloseWithError(io.ErrClosedPipe)
|
||||
},
|
||||
}
|
||||
|
||||
prepared, cancel, cancelable, err := prepareTerminalInputReader(adapter)
|
||||
if err != nil {
|
||||
t.Fatalf("prepare adapter input: %v", err)
|
||||
}
|
||||
if !cancelable {
|
||||
t.Fatal("expected adapter-backed input to be cancelable")
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
buf := make([]byte, 1)
|
||||
_, readErr := prepared.Read(buf)
|
||||
done <- readErr
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case readErr := <-done:
|
||||
if readErr == nil {
|
||||
t.Fatal("expected adapter cancel to interrupt blocking read")
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("blocking adapter input did not unblock after cancel")
|
||||
}
|
||||
}
|
||||
290
terminal_input_test.go
Normal file
290
terminal_input_test.go
Normal file
@ -0,0 +1,290 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type terminalInputProvider struct {
|
||||
io.Reader
|
||||
source io.Reader
|
||||
}
|
||||
|
||||
func (p terminalInputProvider) TerminalInputSource() io.Reader {
|
||||
return p.source
|
||||
}
|
||||
|
||||
type prefixedReadCloser struct {
|
||||
io.Reader
|
||||
io.Closer
|
||||
}
|
||||
|
||||
func TestPrepareTerminalInputReaderBufioReaderPreservesBufferedBytes(t *testing.T) {
|
||||
reader, writer, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create pipe: %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
if _, err := writer.Write([]byte("hello world")); err != nil {
|
||||
writer.Close()
|
||||
t.Fatalf("write pipe: %v", err)
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
t.Fatalf("close writer: %v", err)
|
||||
}
|
||||
|
||||
buffered := bufio.NewReaderSize(reader, 4)
|
||||
peeked, err := buffered.Peek(5)
|
||||
if err != nil {
|
||||
t.Fatalf("prime buffer: %v", err)
|
||||
}
|
||||
if string(peeked) != "hello" {
|
||||
t.Fatalf("unexpected buffered prefix: %q", string(peeked))
|
||||
}
|
||||
|
||||
prepared, cancel, cancelable, err := prepareTerminalInputReader(buffered)
|
||||
if err != nil {
|
||||
t.Fatalf("prepare input: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
if !cancelable {
|
||||
t.Fatal("expected cancelable reader")
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(prepared)
|
||||
if err != nil {
|
||||
t.Fatalf("read prepared input: %v", err)
|
||||
}
|
||||
if string(data) != "hello world" {
|
||||
t.Fatalf("unexpected prepared input: %q", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareTerminalInputReaderBufioReadWriterPreservesBufferedBytes(t *testing.T) {
|
||||
reader, writer, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create pipe: %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
if _, err := writer.Write([]byte("buffered payload")); err != nil {
|
||||
writer.Close()
|
||||
t.Fatalf("write pipe: %v", err)
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
t.Fatalf("close writer: %v", err)
|
||||
}
|
||||
|
||||
readWriter := bufio.NewReadWriter(bufio.NewReaderSize(reader, 8), bufio.NewWriter(io.Discard))
|
||||
if _, err := readWriter.Reader.Peek(8); err != nil {
|
||||
t.Fatalf("prime readwriter buffer: %v", err)
|
||||
}
|
||||
|
||||
prepared, cancel, cancelable, err := prepareTerminalInputReader(readWriter)
|
||||
if err != nil {
|
||||
t.Fatalf("prepare readwriter input: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
if !cancelable {
|
||||
t.Fatal("expected cancelable readwriter input")
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(prepared)
|
||||
if err != nil {
|
||||
t.Fatalf("read prepared readwriter input: %v", err)
|
||||
}
|
||||
if string(data) != "buffered payload" {
|
||||
t.Fatalf("unexpected prepared readwriter input: %q", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareTerminalInputReaderBufioReaderFallbackKeepsData(t *testing.T) {
|
||||
buffered := bufio.NewReader(bytes.NewBufferString("abc123"))
|
||||
if _, err := buffered.Peek(3); err != nil {
|
||||
t.Fatalf("prime buffer: %v", err)
|
||||
}
|
||||
|
||||
prepared, cancel, cancelable, err := prepareTerminalInputReader(buffered)
|
||||
if err != nil {
|
||||
t.Fatalf("prepare fallback input: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
if cancelable {
|
||||
t.Fatal("expected non-cancelable reader")
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(prepared)
|
||||
if err != nil {
|
||||
t.Fatalf("read fallback input: %v", err)
|
||||
}
|
||||
if string(data) != "abc123" {
|
||||
t.Fatalf("unexpected fallback input: %q", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareTerminalInputReaderProviderPrefersExplicitSource(t *testing.T) {
|
||||
reader, writer, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create pipe: %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
if _, err := writer.Write([]byte("provider data")); err != nil {
|
||||
writer.Close()
|
||||
t.Fatalf("write pipe: %v", err)
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
t.Fatalf("close writer: %v", err)
|
||||
}
|
||||
|
||||
buffered := bufio.NewReader(reader)
|
||||
prefix, err := buffered.Peek(len("provider data"))
|
||||
if err != nil {
|
||||
t.Fatalf("prime buffer: %v", err)
|
||||
}
|
||||
|
||||
source := prefixedReadCloser{
|
||||
Reader: io.MultiReader(bytes.NewReader(append([]byte(nil), prefix...)), reader),
|
||||
Closer: reader,
|
||||
}
|
||||
provider := terminalInputProvider{
|
||||
Reader: buffered,
|
||||
source: source,
|
||||
}
|
||||
|
||||
prepared, cancel, cancelable, err := prepareTerminalInputReader(provider)
|
||||
if err != nil {
|
||||
t.Fatalf("prepare provider input: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
if !cancelable {
|
||||
t.Fatal("expected provider-backed input to be cancelable")
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(prepared)
|
||||
if err != nil {
|
||||
t.Fatalf("read provider input: %v", err)
|
||||
}
|
||||
if string(data) != "provider data" {
|
||||
t.Fatalf("unexpected provider input: %q", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareTerminalInputReaderProviderCancelUnblocksRead(t *testing.T) {
|
||||
reader, writer := io.Pipe()
|
||||
defer reader.Close()
|
||||
defer writer.Close()
|
||||
|
||||
buffered := bufio.NewReader(reader)
|
||||
provider := terminalInputProvider{
|
||||
Reader: buffered,
|
||||
source: reader,
|
||||
}
|
||||
|
||||
prepared, cancel, cancelable, err := prepareTerminalInputReader(provider)
|
||||
if err != nil {
|
||||
t.Fatalf("prepare input: %v", err)
|
||||
}
|
||||
|
||||
if !cancelable {
|
||||
t.Fatal("expected provider-backed input to be cancelable")
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
buf := make([]byte, 1)
|
||||
_, readErr := prepared.Read(buf)
|
||||
done <- readErr
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case readErr := <-done:
|
||||
if readErr == nil {
|
||||
t.Fatal("expected cancel to interrupt blocking read")
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("blocking read did not unblock after cancel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareTerminalInputReaderBufioReaderCancelUnblocksRead(t *testing.T) {
|
||||
reader, writer := io.Pipe()
|
||||
defer reader.Close()
|
||||
defer writer.Close()
|
||||
|
||||
buffered := bufio.NewReader(reader)
|
||||
prepared, cancel, cancelable, err := prepareTerminalInputReader(buffered)
|
||||
if err != nil {
|
||||
t.Fatalf("prepare input: %v", err)
|
||||
}
|
||||
|
||||
if !cancelable {
|
||||
t.Fatal("expected bufio reader to be cancelable")
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
buf := make([]byte, 1)
|
||||
_, readErr := prepared.Read(buf)
|
||||
done <- readErr
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case readErr := <-done:
|
||||
if readErr == nil {
|
||||
t.Fatal("expected cancel to interrupt blocking read")
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("blocking bufio reader did not unblock after cancel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareTerminalInputReaderBufioReadWriterCancelUnblocksRead(t *testing.T) {
|
||||
reader, writer := io.Pipe()
|
||||
defer reader.Close()
|
||||
defer writer.Close()
|
||||
|
||||
readWriter := bufio.NewReadWriter(bufio.NewReader(reader), bufio.NewWriter(io.Discard))
|
||||
prepared, cancel, cancelable, err := prepareTerminalInputReader(readWriter)
|
||||
if err != nil {
|
||||
t.Fatalf("prepare readwriter input: %v", err)
|
||||
}
|
||||
|
||||
if !cancelable {
|
||||
t.Fatal("expected bufio readwriter to be cancelable")
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
buf := make([]byte, 1)
|
||||
_, readErr := prepared.Read(buf)
|
||||
done <- readErr
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case readErr := <-done:
|
||||
if readErr == nil {
|
||||
t.Fatal("expected cancel to interrupt blocking read")
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("blocking readwriter input did not unblock after cancel")
|
||||
}
|
||||
}
|
||||
21
terminal_input_unix.go
Normal file
21
terminal_input_unix.go
Normal file
@ -0,0 +1,21 @@
|
||||
//go:build !windows
|
||||
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func duplicateTerminalInputFile(file *os.File) (*os.File, error) {
|
||||
if file == nil {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
|
||||
fd, err := syscall.Dup(int(file.Fd()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
syscall.CloseOnExec(fd)
|
||||
return os.NewFile(uintptr(fd), file.Name()), nil
|
||||
}
|
||||
32
terminal_input_windows.go
Normal file
32
terminal_input_windows.go
Normal file
@ -0,0 +1,32 @@
|
||||
//go:build windows
|
||||
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func duplicateTerminalInputFile(file *os.File) (*os.File, error) {
|
||||
if file == nil {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
|
||||
currentProcess := windows.CurrentProcess()
|
||||
var duplicated windows.Handle
|
||||
err := windows.DuplicateHandle(
|
||||
currentProcess,
|
||||
windows.Handle(file.Fd()),
|
||||
currentProcess,
|
||||
&duplicated,
|
||||
0,
|
||||
false,
|
||||
windows.DUPLICATE_SAME_ACCESS,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return os.NewFile(uintptr(duplicated), file.Name()), nil
|
||||
}
|
||||
336
transport.go
Normal file
336
transport.go
Normal file
@ -0,0 +1,336 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type bufferedConn struct {
|
||||
net.Conn
|
||||
reader *bufio.Reader
|
||||
}
|
||||
|
||||
func (c *bufferedConn) Read(p []byte) (int, error) {
|
||||
if c == nil || c.reader == nil {
|
||||
return 0, io.EOF
|
||||
}
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
func resolveDialContext(info LoginInput) DialContextFunc {
|
||||
if info.DialContext != nil {
|
||||
return info.DialContext
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: info.Timeout,
|
||||
}
|
||||
return dialer.DialContext
|
||||
}
|
||||
|
||||
func dialTargetConn(ctx context.Context, info LoginInput) (net.Conn, *StarSSH, error) {
|
||||
targetAddr := joinHostPort(info.Addr, info.Port)
|
||||
if info.Jump != nil {
|
||||
return dialViaJump(ctx, info, targetAddr)
|
||||
}
|
||||
|
||||
dialContext := resolveDialContext(info)
|
||||
proxyConfig := normalizeProxyConfig(info.Proxy, info.Timeout)
|
||||
if proxyConfig != nil {
|
||||
return dialViaProxy(ctx, dialContext, *proxyConfig, targetAddr)
|
||||
}
|
||||
|
||||
conn, err := dialContext(ctx, "tcp", targetAddr)
|
||||
return conn, nil, err
|
||||
}
|
||||
|
||||
func dialViaJump(ctx context.Context, info LoginInput, targetAddr string) (net.Conn, *StarSSH, error) {
|
||||
if info.Jump == nil {
|
||||
return nil, nil, errors.New("jump login info is nil")
|
||||
}
|
||||
|
||||
jumpClient, err := loginWithContext(ctx, *info.Jump)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
conn, err := jumpClient.dialTCPContext(ctx, "tcp", targetAddr, jumpClient.Close)
|
||||
if err != nil {
|
||||
_ = jumpClient.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
return conn, jumpClient, nil
|
||||
}
|
||||
|
||||
func dialViaProxy(ctx context.Context, dialContext DialContextFunc, proxy ProxyConfig, targetAddr string) (net.Conn, *StarSSH, error) {
|
||||
if dialContext == nil {
|
||||
return nil, nil, errors.New("dial context is nil")
|
||||
}
|
||||
if strings.TrimSpace(proxy.Addr) == "" {
|
||||
return nil, nil, errors.New("proxy address is empty")
|
||||
}
|
||||
|
||||
switch proxy.Type {
|
||||
case ProxyTypeSOCKS5:
|
||||
conn, err := dialSOCKS5(ctx, dialContext, proxy, targetAddr)
|
||||
return conn, nil, err
|
||||
case ProxyTypeHTTPConnect:
|
||||
conn, err := dialHTTPConnect(ctx, dialContext, proxy, targetAddr)
|
||||
return conn, nil, err
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("unsupported proxy type %q", proxy.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func dialHTTPConnect(ctx context.Context, dialContext DialContextFunc, proxy ProxyConfig, targetAddr string) (net.Conn, error) {
|
||||
conn, err := dialContext(ctx, "tcp", proxy.Addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
restoreDeadline := applyConnDeadline(conn, ctx, proxy.Timeout)
|
||||
defer restoreDeadline()
|
||||
|
||||
request := fmt.Sprintf("CONNECT %s HTTP/1.1\r\nHost: %s\r\n", targetAddr, targetAddr)
|
||||
if proxy.Username != "" || proxy.Password != "" {
|
||||
token := base64.StdEncoding.EncodeToString([]byte(proxy.Username + ":" + proxy.Password))
|
||||
request += "Proxy-Authorization: Basic " + token + "\r\n"
|
||||
}
|
||||
request += "\r\n"
|
||||
|
||||
if _, err := io.WriteString(conn, request); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
response, err := http.ReadResponse(reader, &http.Request{Method: http.MethodConnect})
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode < 200 || response.StatusCode >= 300 {
|
||||
_, _ = io.Copy(io.Discard, io.LimitReader(response.Body, 1024))
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("http CONNECT proxy rejected target %s: %s", targetAddr, response.Status)
|
||||
}
|
||||
|
||||
if reader.Buffered() == 0 {
|
||||
return conn, nil
|
||||
}
|
||||
return &bufferedConn{Conn: conn, reader: reader}, nil
|
||||
}
|
||||
|
||||
func dialSOCKS5(ctx context.Context, dialContext DialContextFunc, proxy ProxyConfig, targetAddr string) (net.Conn, error) {
|
||||
conn, err := dialContext(ctx, "tcp", proxy.Addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
restoreDeadline := applyConnDeadline(conn, ctx, proxy.Timeout)
|
||||
defer restoreDeadline()
|
||||
|
||||
methods := []byte{0x00}
|
||||
useAuth := proxy.Username != "" || proxy.Password != ""
|
||||
if useAuth {
|
||||
methods = append(methods, 0x02)
|
||||
}
|
||||
|
||||
hello := append([]byte{0x05, byte(len(methods))}, methods...)
|
||||
if _, err := conn.Write(hello); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, response); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
if response[0] != 0x05 {
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("invalid socks5 version %d", response[0])
|
||||
}
|
||||
if response[1] == 0xFF {
|
||||
_ = conn.Close()
|
||||
return nil, errors.New("socks5 proxy has no acceptable auth method")
|
||||
}
|
||||
|
||||
if response[1] == 0x02 {
|
||||
if err := writeSOCKS5UserPassAuth(conn, proxy.Username, proxy.Password); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := writeSOCKS5Connect(conn, targetAddr); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := readSOCKS5ConnectResponse(conn); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func writeSOCKS5UserPassAuth(conn net.Conn, username string, password string) error {
|
||||
if len(username) > 255 || len(password) > 255 {
|
||||
return errors.New("socks5 username/password too long")
|
||||
}
|
||||
|
||||
request := make([]byte, 0, 3+len(username)+len(password))
|
||||
request = append(request, 0x01, byte(len(username)))
|
||||
request = append(request, []byte(username)...)
|
||||
request = append(request, byte(len(password)))
|
||||
request = append(request, []byte(password)...)
|
||||
|
||||
if _, err := conn.Write(request); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, response); err != nil {
|
||||
return err
|
||||
}
|
||||
if response[1] != 0x00 {
|
||||
return errors.New("socks5 username/password authentication failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeSOCKS5Connect(conn net.Conn, targetAddr string) error {
|
||||
host, portString, err := net.SplitHostPort(targetAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if port < 0 || port > 65535 {
|
||||
return fmt.Errorf("invalid port %d", port)
|
||||
}
|
||||
|
||||
request := []byte{0x05, 0x01, 0x00}
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
request = append(request, 0x01)
|
||||
request = append(request, ip4...)
|
||||
} else {
|
||||
request = append(request, 0x04)
|
||||
request = append(request, ip.To16()...)
|
||||
}
|
||||
} else {
|
||||
if len(host) > 255 {
|
||||
return errors.New("socks5 target host too long")
|
||||
}
|
||||
request = append(request, 0x03, byte(len(host)))
|
||||
request = append(request, []byte(host)...)
|
||||
}
|
||||
|
||||
request = append(request, byte(port>>8), byte(port))
|
||||
_, err = conn.Write(request)
|
||||
return err
|
||||
}
|
||||
|
||||
func readSOCKS5ConnectResponse(conn net.Conn) error {
|
||||
header := make([]byte, 4)
|
||||
if _, err := io.ReadFull(conn, header); err != nil {
|
||||
return err
|
||||
}
|
||||
if header[0] != 0x05 {
|
||||
return fmt.Errorf("invalid socks5 response version %d", header[0])
|
||||
}
|
||||
if header[1] != 0x00 {
|
||||
return fmt.Errorf("socks5 connect failed with code %d", header[1])
|
||||
}
|
||||
|
||||
switch header[3] {
|
||||
case 0x01:
|
||||
if _, err := io.ReadFull(conn, make([]byte, 4)); err != nil {
|
||||
return err
|
||||
}
|
||||
case 0x03:
|
||||
size := make([]byte, 1)
|
||||
if _, err := io.ReadFull(conn, size); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.ReadFull(conn, make([]byte, int(size[0]))); err != nil {
|
||||
return err
|
||||
}
|
||||
case 0x04:
|
||||
if _, err := io.ReadFull(conn, make([]byte, 16)); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported socks5 bind address type %d", header[3])
|
||||
}
|
||||
|
||||
_, err := io.ReadFull(conn, make([]byte, 2))
|
||||
return err
|
||||
}
|
||||
|
||||
func applyConnDeadline(conn net.Conn, ctx context.Context, timeout time.Duration) func() {
|
||||
if conn == nil {
|
||||
return func() {}
|
||||
}
|
||||
|
||||
var (
|
||||
deadline time.Time
|
||||
hasValue bool
|
||||
)
|
||||
if ctx != nil {
|
||||
if ctxDeadline, ok := ctx.Deadline(); ok {
|
||||
deadline = ctxDeadline
|
||||
hasValue = true
|
||||
}
|
||||
}
|
||||
if timeout > 0 {
|
||||
timeoutDeadline := time.Now().Add(timeout)
|
||||
if !hasValue || timeoutDeadline.Before(deadline) {
|
||||
deadline = timeoutDeadline
|
||||
hasValue = true
|
||||
}
|
||||
}
|
||||
if !hasValue {
|
||||
return func() {}
|
||||
}
|
||||
|
||||
_ = conn.SetDeadline(deadline)
|
||||
return func() {
|
||||
_ = conn.SetDeadline(time.Time{})
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeProxyConfig(proxy *ProxyConfig, defaultTimeout time.Duration) *ProxyConfig {
|
||||
if proxy == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
normalized := *proxy
|
||||
normalized.Type = ProxyType(strings.ToLower(strings.TrimSpace(string(normalized.Type))))
|
||||
normalized.Addr = strings.TrimSpace(normalized.Addr)
|
||||
if normalized.Timeout <= 0 {
|
||||
normalized.Timeout = defaultTimeout
|
||||
}
|
||||
return &normalized
|
||||
}
|
||||
|
||||
func joinHostPort(host string, port int) string {
|
||||
return net.JoinHostPort(strings.TrimSpace(host), strconv.Itoa(port))
|
||||
}
|
||||
196
types.go
Normal file
196
types.go
Normal file
@ -0,0 +1,196 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/sftp"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultSSHPort = 22
|
||||
defaultLoginTimeout = 5 * time.Second
|
||||
defaultKeepAliveTimeout = 3 * time.Second
|
||||
defaultShellPollInterval = 120 * time.Millisecond
|
||||
defaultShellSetupDelay = 200 * time.Millisecond
|
||||
defaultShellSetupTimeout = 3 * time.Second
|
||||
defaultShellWaitTimeout = 30 * time.Second
|
||||
defaultShellPromptToken = "__STARSSH_PROMPT__>"
|
||||
defaultPTYTerm = "xterm"
|
||||
defaultPTYRows = 500
|
||||
defaultPTYColumns = 250
|
||||
|
||||
defaultTransferBufferSize = 1024 * 1024
|
||||
|
||||
defaultExecStreamMaxPendingChunks = 256
|
||||
defaultExecStreamMaxPendingBytes = 4 * 1024 * 1024
|
||||
)
|
||||
|
||||
type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error)
|
||||
|
||||
type ProxyType string
|
||||
|
||||
const (
|
||||
ProxyTypeSOCKS5 ProxyType = "socks5"
|
||||
ProxyTypeHTTPConnect ProxyType = "http_connect"
|
||||
)
|
||||
|
||||
type ProxyConfig struct {
|
||||
Type ProxyType
|
||||
Addr string
|
||||
Username string
|
||||
Password string
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
type AuthMethodKind string
|
||||
|
||||
const (
|
||||
AuthMethodPrivateKey AuthMethodKind = "private_key"
|
||||
AuthMethodPassword AuthMethodKind = "password"
|
||||
AuthMethodKeyboardInteractive AuthMethodKind = "keyboard_interactive"
|
||||
AuthMethodSSHAgent AuthMethodKind = "ssh_agent"
|
||||
)
|
||||
|
||||
type StarSSH struct {
|
||||
stateMu sync.RWMutex
|
||||
Client *ssh.Client
|
||||
PublicKey ssh.PublicKey
|
||||
PubkeyBase64 string
|
||||
Hostname string
|
||||
RemoteAddr net.Addr
|
||||
Banner string
|
||||
LoginInfo LoginInput
|
||||
online bool
|
||||
upstream *StarSSH
|
||||
sftpClient *sftp.Client
|
||||
sftpMu sync.Mutex
|
||||
keepaliveMu sync.Mutex
|
||||
keepaliveStop chan struct{}
|
||||
keepaliveDone chan struct{}
|
||||
}
|
||||
|
||||
type LoginInput struct {
|
||||
KeyExchanges []string
|
||||
Ciphers []string
|
||||
MACs []string
|
||||
User string
|
||||
Password string
|
||||
PasswordCallback func() (string, error)
|
||||
KeyboardInteractiveCallback ssh.KeyboardInteractiveChallenge
|
||||
Prikey string
|
||||
PrikeyPwd string
|
||||
DisableSSHAgent bool
|
||||
AuthOrder []AuthMethodKind
|
||||
Addr string
|
||||
Port int
|
||||
Timeout time.Duration
|
||||
DialContext DialContextFunc
|
||||
Proxy *ProxyConfig
|
||||
Jump *LoginInput
|
||||
KeepAliveInterval time.Duration
|
||||
KeepAliveTimeout time.Duration
|
||||
HostKeyCallback func(string, net.Addr, ssh.PublicKey) error
|
||||
BannerCallback func(string) error
|
||||
}
|
||||
|
||||
// StarShell keeps the legacy prompt-driven helper for POSIX-style scripted shell interactions.
|
||||
// It is not a generic cross-shell abstraction; for product-grade interactive terminals, prefer TerminalSession.
|
||||
type StarShell struct {
|
||||
Keyword string
|
||||
UseWaitDefault bool
|
||||
WaitTimeout time.Duration
|
||||
Session *ssh.Session
|
||||
in io.Writer
|
||||
out *bufio.Reader
|
||||
er *bufio.Reader
|
||||
outbyte []byte
|
||||
errbyte []byte
|
||||
lastout int64
|
||||
errors error
|
||||
isprint bool
|
||||
isfuncs bool
|
||||
iscolor bool
|
||||
isecho bool
|
||||
rw sync.RWMutex
|
||||
funcs func(string)
|
||||
writeMu sync.Mutex
|
||||
commandMu sync.Mutex
|
||||
closeOnce sync.Once
|
||||
promptToken string
|
||||
}
|
||||
|
||||
type TerminalConfig struct {
|
||||
Term string
|
||||
Rows int
|
||||
Columns int
|
||||
Modes ssh.TerminalModes
|
||||
}
|
||||
|
||||
type TerminalControl byte
|
||||
|
||||
const (
|
||||
TerminalControlInterrupt TerminalControl = 0x03
|
||||
TerminalControlEOF TerminalControl = 0x04
|
||||
TerminalControlBell TerminalControl = 0x07
|
||||
TerminalControlBackspace TerminalControl = 0x08
|
||||
TerminalControlLineKill TerminalControl = 0x15
|
||||
TerminalControlQuit TerminalControl = 0x1c
|
||||
TerminalControlSuspend TerminalControl = 0x1a
|
||||
TerminalControlPauseOutput TerminalControl = 0x13
|
||||
TerminalControlResumeOutput TerminalControl = 0x11
|
||||
)
|
||||
|
||||
type TerminalCloseReason string
|
||||
|
||||
const (
|
||||
TerminalCloseReasonUnknown TerminalCloseReason = ""
|
||||
TerminalCloseReasonExit TerminalCloseReason = "exit"
|
||||
TerminalCloseReasonSignal TerminalCloseReason = "signal"
|
||||
TerminalCloseReasonClosed TerminalCloseReason = "closed"
|
||||
TerminalCloseReasonContextCanceled TerminalCloseReason = "context_canceled"
|
||||
TerminalCloseReasonDeadlineExceeded TerminalCloseReason = "deadline_exceeded"
|
||||
TerminalCloseReasonTransportError TerminalCloseReason = "transport_error"
|
||||
)
|
||||
|
||||
type TerminalExitInfo struct {
|
||||
ExitCode int
|
||||
ExitSignal string
|
||||
ExitMessage string
|
||||
Reason TerminalCloseReason
|
||||
}
|
||||
|
||||
type TerminalSession struct {
|
||||
Session *ssh.Session
|
||||
ID string
|
||||
Label string
|
||||
Metadata map[string]string
|
||||
stdin io.WriteCloser
|
||||
stdout io.Reader
|
||||
stderr io.Reader
|
||||
|
||||
attachMu sync.RWMutex
|
||||
in io.Reader
|
||||
out io.Writer
|
||||
errOut io.Writer
|
||||
|
||||
runOnce sync.Once
|
||||
runDone chan struct{}
|
||||
runErr error
|
||||
|
||||
waitOnce sync.Once
|
||||
waitDone chan struct{}
|
||||
|
||||
stateMu sync.RWMutex
|
||||
waitErr error
|
||||
exitInfo TerminalExitInfo
|
||||
closeReason TerminalCloseReason
|
||||
closeErr error
|
||||
|
||||
closeOnce sync.Once
|
||||
}
|
||||
162
utils.go
Normal file
162
utils.go
Normal file
@ -0,0 +1,162 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ansiCSIRegexp = regexp.MustCompile(`\x1b\[[0-9;?]*[ -/]*[@-~]`)
|
||||
ansiOSCRegexp = regexp.MustCompile(`\x1b\][^\x07]*(\x07|\x1b\\)`)
|
||||
leadingIntRegexp = regexp.MustCompile(`^[+-]?\d+`)
|
||||
)
|
||||
|
||||
func SedColor(str string) string {
|
||||
return stripControlSequences(str)
|
||||
}
|
||||
|
||||
func normalizeShellOutput(raw string) string {
|
||||
return strings.TrimSpace(strings.ReplaceAll(raw, "\r\n", "\n"))
|
||||
}
|
||||
|
||||
func stripControlSequences(raw string) string {
|
||||
cleaned := ansiOSCRegexp.ReplaceAllString(raw, "")
|
||||
cleaned = ansiCSIRegexp.ReplaceAllString(cleaned, "")
|
||||
cleaned = strings.ReplaceAll(cleaned, "\r", "")
|
||||
cleaned = strings.Map(func(r rune) rune {
|
||||
if r == '\n' || r == '\t' {
|
||||
return r
|
||||
}
|
||||
if r < 0x20 || r == 0x7f {
|
||||
return -1
|
||||
}
|
||||
return r
|
||||
}, cleaned)
|
||||
return cleaned
|
||||
}
|
||||
|
||||
func stripLeadingEcho(output string, command string, markerCommand string) string {
|
||||
result := output
|
||||
if command == "" {
|
||||
return result
|
||||
}
|
||||
|
||||
withMarker := command
|
||||
if markerCommand != "" {
|
||||
withMarker += "\n" + markerCommand
|
||||
}
|
||||
if strings.HasPrefix(result, withMarker) {
|
||||
return strings.TrimSpace(strings.TrimPrefix(result, withMarker))
|
||||
}
|
||||
if strings.HasPrefix(result, command) {
|
||||
return strings.TrimSpace(strings.TrimPrefix(result, command))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func collectLinesWithoutTokens(output string, tokens ...string) string {
|
||||
lines := strings.Split(output, "\n")
|
||||
filtered := make([]string, 0, len(lines))
|
||||
|
||||
for _, line := range lines {
|
||||
skip := false
|
||||
for _, token := range tokens {
|
||||
if token != "" && strings.Contains(line, token) {
|
||||
skip = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !skip {
|
||||
filtered = append(filtered, line)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.TrimSpace(strings.Join(filtered, "\n"))
|
||||
}
|
||||
|
||||
func collectLinesForCommandOutput(output string, promptToken string, tokens ...string) string {
|
||||
lines := strings.Split(output, "\n")
|
||||
filtered := make([]string, 0, len(lines))
|
||||
|
||||
for _, line := range lines {
|
||||
skip := false
|
||||
for _, token := range tokens {
|
||||
if token != "" && strings.Contains(line, token) {
|
||||
skip = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !skip && promptToken != "" && strings.Contains(line, promptToken) {
|
||||
skip = true
|
||||
}
|
||||
if !skip {
|
||||
filtered = append(filtered, line)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.TrimSpace(strings.Join(filtered, "\n"))
|
||||
}
|
||||
|
||||
func shellSingleQuote(s string) string {
|
||||
return "'" + strings.ReplaceAll(s, "'", "'\"'\"'") + "'"
|
||||
}
|
||||
|
||||
func normalizeBufferSize(bufcap int) int {
|
||||
if bufcap <= 0 {
|
||||
return defaultTransferBufferSize
|
||||
}
|
||||
return bufcap
|
||||
}
|
||||
|
||||
func newCommandTokens() (beginToken string, endToken string) {
|
||||
nonce := newNonce(8)
|
||||
return "__STARSSH_BEGIN_" + nonce + "__", "__STARSSH_END_" + nonce + "__"
|
||||
}
|
||||
|
||||
func newNonce(size int) string {
|
||||
if size <= 0 {
|
||||
size = 8
|
||||
}
|
||||
|
||||
buf := make([]byte, size)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return fmt.Sprintf("%d", time.Now().UnixNano())
|
||||
}
|
||||
return strings.ToUpper(hex.EncodeToString(buf))
|
||||
}
|
||||
|
||||
func splitByEndToken(output string, endToken string) (before string, exitCode int, found bool, parseErr error) {
|
||||
prefix := endToken + ":"
|
||||
lines := strings.Split(output, "\n")
|
||||
beforeLines := make([]string, 0, len(lines))
|
||||
|
||||
for _, line := range lines {
|
||||
trimmedLine := strings.TrimSpace(line)
|
||||
if !strings.HasPrefix(trimmedLine, prefix) {
|
||||
beforeLines = append(beforeLines, line)
|
||||
continue
|
||||
}
|
||||
|
||||
codeText := strings.TrimSpace(trimmedLine[len(prefix):])
|
||||
match := leadingIntRegexp.FindString(codeText)
|
||||
if match == "" || match != codeText {
|
||||
beforeLines = append(beforeLines, line)
|
||||
continue
|
||||
}
|
||||
|
||||
code, err := strconv.Atoi(match)
|
||||
if err != nil {
|
||||
beforeLines = append(beforeLines, line)
|
||||
continue
|
||||
}
|
||||
|
||||
return strings.Join(beforeLines, "\n"), code, true, nil
|
||||
}
|
||||
|
||||
return strings.Join(beforeLines, "\n"), 0, false, nil
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user