From f20eb653ae69e04aadd4a2f6fbfe3eeb56432e09 Mon Sep 17 00:00:00 2001 From: starainrt Date: Sun, 26 Apr 2026 10:45:39 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84=20starssh=20?= =?UTF-8?q?=E6=A0=B8=E5=BF=83=E8=BF=90=E8=A1=8C=E6=97=B6=E5=B9=B6=E8=A1=A5?= =?UTF-8?q?=E5=BC=BA=20ssh/exec/terminal/sftp=20=E8=83=BD=E5=8A=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 拆分原有单体 ssh.go,按职责重组为 types、utils、transport、login、keepalive、session、exec、pool、shell、terminal、forward、hostkey、state 等模块,并补充平台相关实现 - 重做登录与连接运行时,补齐基于 context 的建连、jump/proxy 链路、可配置认证顺序,以及 Unix/Windows 下的 ssh-agent 支持 - 新增正式非交互执行模型 ExecRequest/ExecResult,支持流式输出、溢出统计、超时控制,以及 posix/powershell/cmd/raw 多方言执行 - 保留旧 shell 风格兼容接口,同时让路径/用户探测等 helper 具备跨 shell fallback,避免 Windows 目标继续硬依赖 POSIX 命令 - 新增 TerminalSession 作为原始交互终端基座,提供 IO attach、resize、signal/control、退出状态与关闭原因管理 - 重构端口转发语义,默认复用当前 SSH 连接,并显式提供 detached 的本地/动态转发模式承载隔离场景 - 梳理 keepalive 与取消语义,区分仅取消本次操作和关闭整条连接,并统一连接状态与传输关闭路径 - 围绕新的 session/连接生命周期重做执行池与运行时支撑 - 大幅增强 SFTP 传输链路,补齐更安全的原子替换、校验、进度回调、重试隔离、可复用 client 生命周期与失败语义 - 新增取消语义、keepalive、SFTP、forward、terminal input 等关键回归测试,提升核心链路稳定性 --- .gitignore | 9 + LICENSE | 201 ++++ cancel_semantics_test.go | 172 ++++ exec.go | 1218 ++++++++++++++++++++++++ exec_legacy_test.go | 69 ++ forward.go | 694 ++++++++++++++ forward_test.go | 164 ++++ go.mod | 15 +- go.sum | 93 +- hostkey.go | 290 ++++++ keepalive.go | 135 +++ keepalive_test.go | 53 ++ login.go | 362 ++++++++ pool.go | 466 ++++++++++ session.go | 121 +++ sftp.go | 1603 +++++++++++++++++++++++++++++--- sftp_test.go | 475 ++++++++++ shell.go | 592 ++++++++++++ ssh.go | 636 ------------- sshagent_unix.go | 21 + sshagent_windows.go | 70 ++ state.go | 106 +++ terminal.go | 431 +++++++++ terminal_input.go | 225 +++++ terminal_input_adapter_test.go | 49 + terminal_input_test.go | 290 ++++++ terminal_input_unix.go | 21 + terminal_input_windows.go | 32 + transport.go | 336 +++++++ types.go | 196 ++++ utils.go | 162 ++++ 31 files changed, 8538 insertions(+), 769 deletions(-) create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 cancel_semantics_test.go create mode 100644 exec.go create mode 100644 exec_legacy_test.go create mode 100644 forward.go create mode 100644 forward_test.go create mode 100644 hostkey.go create mode 100644 keepalive.go create mode 100644 keepalive_test.go create mode 100644 login.go create mode 100644 pool.go create mode 100644 session.go create mode 100644 sftp_test.go create mode 100644 shell.go delete mode 100644 ssh.go create mode 100644 sshagent_unix.go create mode 100644 sshagent_windows.go create mode 100644 state.go create mode 100644 terminal.go create mode 100644 terminal_input.go create mode 100644 terminal_input_adapter_test.go create mode 100644 terminal_input_test.go create mode 100644 terminal_input_unix.go create mode 100644 terminal_input_windows.go create mode 100644 transport.go create mode 100644 types.go create mode 100644 utils.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5f7de56 --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ +.sentrux/ +agent_readme.md +target.md +.gocache/ +.tmp_*/ +.codex/ +.idea/ +agents.md +.codex \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..9590d39 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, +and distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by +the copyright owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all +other entities that control, are controlled by, or are under common +control with that entity. For the purposes of this definition, +"control" means (i) the power, direct or indirect, to cause the +direction or management of such entity, whether by contract or +otherwise, or (ii) ownership of fifty percent (50%) or more of the +outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity +exercising permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, +including but not limited to software source code, documentation +source, and configuration files. + +"Object" form shall mean any form resulting from mechanical +transformation or translation of a Source form, including but +not limited to compiled object code, generated documentation, +and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or +Object form, made available under the License, as indicated by a +copyright notice that is included in or attached to the work +(an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object +form, that is based on (or derived from) the Work and for which the +editorial revisions, annotations, elaborations, or other modifications +represent, as a whole, an original work of authorship. For the purposes +of this License, Derivative Works shall not include works that remain +separable from, or merely link (or bind by name) to the interfaces of, +the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including +the original version of the Work and any modifications or additions +to that Work or Derivative Works thereof, that is intentionally +submitted to Licensor for inclusion in the Work by the copyright owner +or by an individual or Legal Entity authorized to submit on behalf of +the copyright owner. For the purposes of this definition, "submitted" +means any form of electronic, verbal, or written communication sent +to the Licensor or its representatives, including but not limited to +communication on electronic mailing lists, source code control systems, +and issue tracking systems that are managed by, or on behalf of, the +Licensor for the purpose of discussing and improving the Work, but +excluding communication that is conspicuously marked or otherwise +designated in writing by the copyright owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity +on behalf of whom a Contribution has been received by Licensor and +subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of +this License, each Contributor hereby grants to You a perpetual, +worldwide, non-exclusive, no-charge, royalty-free, irrevocable +copyright license to reproduce, prepare Derivative Works of, +publicly display, publicly perform, sublicense, and distribute the +Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of +this License, each Contributor hereby grants to You a perpetual, +worldwide, non-exclusive, no-charge, royalty-free, irrevocable +(except as stated in this section) patent license to make, have made, +use, offer to sell, sell, import, and otherwise transfer the Work, +where such license applies only to those patent claims licensable +by such Contributor that are necessarily infringed by their +Contribution(s) alone or by combination of their Contribution(s) +with the Work to which such Contribution(s) was submitted. If You +institute patent litigation against any entity (including a +cross-claim or counterclaim in a lawsuit) alleging that the Work +or a Contribution incorporated within the Work constitutes direct +or contributory patent infringement, then any patent licenses +granted to You under this License for that Work shall terminate +as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the +Work or Derivative Works thereof in any medium, with or without +modifications, and in Source or Object form, provided that You +meet the following conditions: + +(a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + +(b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + +(c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + +(d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + +You may add Your own copyright statement to Your modifications and +may provide additional or different license terms and conditions +for use, reproduction, or distribution of Your modifications, or +for any such Derivative Works as a whole, provided Your use, +reproduction, and distribution of the Work otherwise complies with +the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, +any Contribution intentionally submitted for inclusion in the Work +by You to the Licensor shall be under the terms and conditions of +this License, without any additional terms or conditions. +Notwithstanding the above, nothing herein shall supersede or modify +the terms of any separate license agreement you may have executed +with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade +names, trademarks, service marks, or product names of the Licensor, +except as required for reasonable and customary use in describing the +origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or +agreed to in writing, Licensor provides the Work (and each +Contributor provides its Contributions) on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied, including, without limitation, any warranties or conditions +of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A +PARTICULAR PURPOSE. You are solely responsible for determining the +appropriateness of using or redistributing the Work and assume any +risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, +whether in tort (including negligence), contract, or otherwise, +unless required by applicable law (such as deliberate and grossly +negligent acts) or agreed to in writing, shall any Contributor be +liable to You for damages, including any direct, indirect, special, +incidental, or consequential damages of any character arising as a +result of this License or out of the use or inability to use the +Work (including but not limited to damages for loss of goodwill, +work stoppage, computer failure or malfunction, or any and all +other commercial damages or losses), even if such Contributor +has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing +the Work or Derivative Works thereof, You may choose to offer, +and charge a fee for, acceptance of support, warranty, indemnity, +or other liability obligations and/or rights consistent with this +License. However, in accepting such obligations, You may act only +on Your own behalf and on Your sole responsibility, not on behalf +of any other Contributor, and only if You agree to indemnify, +defend, and hold each Contributor harmless for any liability +incurred by, or claims asserted against, such Contributor by reason +of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + +To apply the Apache License to your work, attach the following +boilerplate notice, with the fields enclosed by brackets "[]" +replaced with your own identifying information. (Don't include +the brackets!) The text should be enclosed in the appropriate +comment syntax for the file format. We also recommend that a +file or class name and description of purpose be included on the +same "printed page" as the copyright notice for easier +identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/cancel_semantics_test.go b/cancel_semantics_test.go new file mode 100644 index 0000000..fe7ab5e --- /dev/null +++ b/cancel_semantics_test.go @@ -0,0 +1,172 @@ +package starssh + +import ( + "context" + "errors" + "net" + "sync/atomic" + "testing" + "time" + + "golang.org/x/crypto/ssh" +) + +func TestPingContextDoesNotCloseConnectionOnCancel(t *testing.T) { + oldSendKeepAliveRequest := sendKeepAliveRequest + oldCloseSSHClient := closeSSHClient + t.Cleanup(func() { + sendKeepAliveRequest = oldSendKeepAliveRequest + closeSSHClient = oldCloseSSHClient + }) + + sendKeepAliveRequest = func(ctx context.Context, client sshClientRequester) error { + <-ctx.Done() + time.Sleep(20 * time.Millisecond) + return ctx.Err() + } + + var closeCalls atomic.Int32 + closeSSHClient = func(client sshClientRequester) error { + closeCalls.Add(1) + return nil + } + + client := &ssh.Client{} + star := &StarSSH{} + star.setTransport(client, nil) + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + + err := star.PingContext(ctx) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected deadline exceeded, got %v", err) + } + if closeCalls.Load() != 0 { + t.Fatalf("expected PingContext to keep connection open, close calls=%d", closeCalls.Load()) + } + if got := star.snapshotSSHClient(); got != client { + t.Fatal("expected ssh client to remain attached after PingContext cancel") + } +} + +func TestPingContextCloseOnCancelClosesConnection(t *testing.T) { + oldSendKeepAliveRequest := sendKeepAliveRequest + oldCloseSSHClient := closeSSHClient + t.Cleanup(func() { + sendKeepAliveRequest = oldSendKeepAliveRequest + closeSSHClient = oldCloseSSHClient + }) + + sendKeepAliveRequest = func(ctx context.Context, client sshClientRequester) error { + <-ctx.Done() + time.Sleep(20 * time.Millisecond) + return ctx.Err() + } + + var closeCalls atomic.Int32 + closeSSHClient = func(client sshClientRequester) error { + closeCalls.Add(1) + return nil + } + + star := &StarSSH{} + star.setTransport(&ssh.Client{}, nil) + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + + err := star.PingContextCloseOnCancel(ctx) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected deadline exceeded, got %v", err) + } + if closeCalls.Load() != 1 { + t.Fatalf("expected exactly one close call, got %d", closeCalls.Load()) + } + if got := star.snapshotSSHClient(); got != nil { + t.Fatal("expected ssh client to be detached after PingContextCloseOnCancel") + } +} + +func TestDialTCPContextDoesNotCloseConnectionOnCancel(t *testing.T) { + oldDialSSHClient := dialSSHClient + oldCloseSSHClient := closeSSHClient + t.Cleanup(func() { + dialSSHClient = oldDialSSHClient + closeSSHClient = oldCloseSSHClient + }) + + dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) { + <-ctx.Done() + time.Sleep(20 * time.Millisecond) + return nil, ctx.Err() + } + + var closeCalls atomic.Int32 + closeSSHClient = func(client sshClientRequester) error { + closeCalls.Add(1) + return nil + } + + client := &ssh.Client{} + star := &StarSSH{} + star.setTransport(client, nil) + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + + conn, err := star.DialTCPContext(ctx, "tcp", "127.0.0.1:22") + if conn != nil { + t.Fatal("expected nil connection on canceled dial") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected deadline exceeded, got %v", err) + } + if closeCalls.Load() != 0 { + t.Fatalf("expected DialTCPContext to keep connection open, close calls=%d", closeCalls.Load()) + } + if got := star.snapshotSSHClient(); got != client { + t.Fatal("expected ssh client to remain attached after DialTCPContext cancel") + } +} + +func TestDialTCPContextCloseOnCancelClosesConnection(t *testing.T) { + oldDialSSHClient := dialSSHClient + oldCloseSSHClient := closeSSHClient + t.Cleanup(func() { + dialSSHClient = oldDialSSHClient + closeSSHClient = oldCloseSSHClient + }) + + dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) { + <-ctx.Done() + time.Sleep(20 * time.Millisecond) + return nil, ctx.Err() + } + + var closeCalls atomic.Int32 + closeSSHClient = func(client sshClientRequester) error { + closeCalls.Add(1) + return nil + } + + star := &StarSSH{} + star.setTransport(&ssh.Client{}, nil) + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + + conn, err := star.DialTCPContextCloseOnCancel(ctx, "tcp", "127.0.0.1:22") + if conn != nil { + t.Fatal("expected nil connection on canceled dial") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected deadline exceeded, got %v", err) + } + if closeCalls.Load() != 1 { + t.Fatalf("expected exactly one close call, got %d", closeCalls.Load()) + } + if got := star.snapshotSSHClient(); got != nil { + t.Fatal("expected ssh client to be detached after DialTCPContextCloseOnCancel") + } +} diff --git a/exec.go b/exec.go new file mode 100644 index 0000000..5b9d393 --- /dev/null +++ b/exec.go @@ -0,0 +1,1218 @@ +package starssh + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "sort" + "strconv" + "strings" + "sync" + "time" + "unicode" + + "golang.org/x/crypto/ssh" +) + +type ExecRequest struct { + Command string + Stdin []byte + Env map[string]string + Dir string + ShellDialect ExecShellDialect + Timeout time.Duration + PTY *TerminalConfig + DiscardOutput bool + MaxOutputBytes int + StreamMaxPendingChunks int + StreamMaxPendingBytes int + StreamOverflowStrategy ExecStreamOverflowStrategy +} + +type ExecResult struct { + Command string + Stdout []byte + Stderr []byte + Combined []byte + StdoutTruncated bool + StderrTruncated bool + CombinedTruncated bool + StreamDroppedChunks int + StreamDroppedBytes int + ExitCode int + ExitSignal string + ExitMessage string + Duration time.Duration +} + +type ExecStreamChunk struct { + Data []byte + Stderr bool +} + +type ExecShellDialect string + +const ( + ExecShellDialectPOSIX ExecShellDialect = "posix" + ExecShellDialectPowerShell ExecShellDialect = "powershell" + ExecShellDialectCMD ExecShellDialect = "cmd" + ExecShellDialectRaw ExecShellDialect = "raw" +) + +type ExecStreamOverflowStrategy string + +const ( + ExecStreamOverflowDropOldest ExecStreamOverflowStrategy = "drop_oldest" + ExecStreamOverflowDropNewest ExecStreamOverflowStrategy = "drop_newest" + ExecStreamOverflowFail ExecStreamOverflowStrategy = "fail" +) + +type ExecExitError struct { + Status int + Signal string + Message string + Stderr string +} + +type ExecStreamOverflowError struct { + DroppedChunks int + DroppedBytes int +} + +type ShellExitError = ExecExitError + +var execRequestRunner = func(s *StarSSH, ctx context.Context, req ExecRequest) (*ExecResult, error) { + return s.Exec(ctx, req) +} + +func (e *ExecExitError) Error() string { + if e == nil { + return "" + } + + base := "remote command exited" + if e.Status != 0 { + base += " with status " + strconv.Itoa(e.Status) + } + if e.Signal != "" { + base += " from signal " + e.Signal + } + if e.Message != "" { + base += ": " + e.Message + } + if e.Stderr != "" { + base += ": " + e.Stderr + } + return base +} + +func (e *ExecExitError) ExitStatus() int { + if e == nil { + return 0 + } + return e.Status +} + +func (e *ExecStreamOverflowError) Error() string { + if e == nil { + return "" + } + return fmt.Sprintf("exec stream callback queue overflow: dropped %d chunks (%d bytes)", e.DroppedChunks, e.DroppedBytes) +} + +func (r *ExecResult) Success() bool { + return r != nil && r.ExitCode == 0 && r.ExitSignal == "" +} + +func (r *ExecResult) StdoutString() string { + if r == nil { + return "" + } + return string(r.Stdout) +} + +func (r *ExecResult) StderrString() string { + if r == nil { + return "" + } + return string(r.Stderr) +} + +func (r *ExecResult) CombinedString() string { + if r == nil { + return "" + } + return string(r.Combined) +} + +func (r *ExecResult) CommandError() error { + if r == nil || r.Success() { + return nil + } + return &ExecExitError{ + Status: r.ExitCode, + Signal: r.ExitSignal, + Message: strings.TrimSpace(r.ExitMessage), + Stderr: strings.TrimSpace(r.StderrString()), + } +} + +func (r *ExecResult) OutputTruncated() bool { + if r == nil { + return false + } + return r.StdoutTruncated || r.StderrTruncated || r.CombinedTruncated +} + +func (r *ExecResult) StreamOutputDropped() bool { + return r != nil && (r.StreamDroppedChunks > 0 || r.StreamDroppedBytes > 0) +} + +func (r *ExecResult) StreamOverflowError() error { + if r == nil || !r.StreamOutputDropped() { + return nil + } + return &ExecStreamOverflowError{ + DroppedChunks: r.StreamDroppedChunks, + DroppedBytes: r.StreamDroppedBytes, + } +} + +func (s *StarSSH) Exec(ctx context.Context, req ExecRequest) (*ExecResult, error) { + return s.exec(ctx, req, nil) +} + +func (s *StarSSH) ExecString(ctx context.Context, command string) (*ExecResult, error) { + return s.Exec(ctx, ExecRequest{ + Command: command, + }) +} + +func (s *StarSSH) ExecStream(ctx context.Context, req ExecRequest, onChunk func(ExecStreamChunk)) (*ExecResult, error) { + return s.exec(ctx, req, onChunk) +} + +func (s *StarSSH) exec(ctx context.Context, req ExecRequest, onChunk func(ExecStreamChunk)) (*ExecResult, error) { + if s == nil { + return nil, errors.New("ssh client is nil") + } + if ctx == nil { + ctx = context.Background() + } + if req.Timeout > 0 { + timeoutCtx, cancel := context.WithTimeout(ctx, req.Timeout) + defer cancel() + ctx = timeoutCtx + } + + remoteCommand, err := buildExecCommand(req) + if err != nil { + return nil, err + } + + session, err := s.newExecRuntimeSession(req) + if err != nil { + return nil, err + } + defer session.Close() + + var stdin io.WriteCloser + if req.Stdin != nil { + stdin, err = session.StdinPipe() + if err != nil { + return nil, err + } + } + + stdout, err := session.StdoutPipe() + if err != nil { + return nil, err + } + stderr, err := session.StderrPipe() + if err != nil { + return nil, err + } + + result := &ExecResult{ + Command: req.Command, + } + + startAt := time.Now() + if err := session.Start(remoteCommand); err != nil { + return nil, err + } + + if stdin != nil { + go func() { + _, _ = stdin.Write(req.Stdin) + _ = stdin.Close() + }() + } + + chunks := make(chan ExecStreamChunk, 16) + readErrs := make(chan error, 2) + + var readWG sync.WaitGroup + readWG.Add(2) + go streamExecReader(stdout, false, chunks, readErrs, &readWG) + go streamExecReader(stderr, true, chunks, readErrs, &readWG) + go func() { + readWG.Wait() + close(chunks) + close(readErrs) + }() + + stdoutBuf := newCaptureBuffer(req.MaxOutputBytes, req.DiscardOutput) + stderrBuf := newCaptureBuffer(req.MaxOutputBytes, req.DiscardOutput) + combinedBuf := newCaptureBuffer(req.MaxOutputBytes, req.DiscardOutput) + dispatcher, err := newExecChunkDispatcher(req, onChunk) + if err != nil { + return nil, err + } + drainDone := make(chan struct{}) + go func() { + defer close(drainDone) + defer dispatcher.Close() + for chunk := range chunks { + if len(chunk.Data) == 0 { + continue + } + _, _ = combinedBuf.Write(chunk.Data) + if chunk.Stderr { + _, _ = stderrBuf.Write(chunk.Data) + } else { + _, _ = stdoutBuf.Write(chunk.Data) + } + dispatcher.Enqueue(chunk) + } + }() + + waitCh := make(chan error, 1) + go func() { + waitCh <- session.Wait() + }() + + var waitErr error + select { + case waitErr = <-waitCh: + case <-ctx.Done(): + _ = session.Close() + waitErr = ctx.Err() + } + + <-drainDone + dispatchStats := dispatcher.Wait() + readErr := firstExecError(readErrs) + + result.Stdout = append(result.Stdout[:0], stdoutBuf.Bytes()...) + result.Stderr = append(result.Stderr[:0], stderrBuf.Bytes()...) + result.Combined = append(result.Combined[:0], combinedBuf.Bytes()...) + result.StdoutTruncated = stdoutBuf.Truncated() + result.StderrTruncated = stderrBuf.Truncated() + result.CombinedTruncated = combinedBuf.Truncated() + result.StreamDroppedChunks = dispatchStats.droppedChunks + result.StreamDroppedBytes = dispatchStats.droppedBytes + result.Duration = time.Since(startAt) + + if errors.Is(waitErr, context.Canceled) || errors.Is(waitErr, context.DeadlineExceeded) { + return result, waitErr + } + + var exitErr *ssh.ExitError + if errors.As(waitErr, &exitErr) { + result.ExitCode = exitErr.ExitStatus() + result.ExitSignal = exitErr.Signal() + result.ExitMessage = exitErr.Msg() + waitErr = nil + } + + if readErr != nil { + return result, readErr + } + if dispatchStats.err != nil { + return result, dispatchStats.err + } + + if waitErr == nil { + return result, nil + } + + return result, waitErr +} + +func (s *StarSSH) newExecRuntimeSession(req ExecRequest) (*ssh.Session, error) { + if req.PTY != nil { + return s.NewPTYSession(req.PTY) + } + return s.NewExecSession() +} + +func buildExecCommand(req ExecRequest) (string, error) { + if strings.TrimSpace(req.Command) == "" { + return "", errors.New("command is empty") + } + + dialect, err := normalizeExecShellDialect(req.ShellDialect) + if err != nil { + return "", err + } + + switch dialect { + case ExecShellDialectPOSIX: + return buildExecCommandPOSIX(req) + case ExecShellDialectPowerShell: + return buildExecCommandPowerShell(req) + case ExecShellDialectCMD: + return buildExecCommandCMD(req) + case ExecShellDialectRaw: + return buildExecCommandRaw(req) + default: + return "", fmt.Errorf("unsupported exec shell dialect %q", req.ShellDialect) + } +} + +func normalizeExecShellDialect(dialect ExecShellDialect) (ExecShellDialect, error) { + if strings.TrimSpace(string(dialect)) == "" { + return ExecShellDialectPOSIX, nil + } + + switch ExecShellDialect(strings.ToLower(strings.TrimSpace(string(dialect)))) { + case ExecShellDialectPOSIX, ExecShellDialectPowerShell, ExecShellDialectCMD, ExecShellDialectRaw: + return ExecShellDialect(strings.ToLower(strings.TrimSpace(string(dialect)))), nil + default: + return "", fmt.Errorf("invalid exec shell dialect %q", dialect) + } +} + +func buildExecCommandPOSIX(req ExecRequest) (string, error) { + parts := make([]string, 0, 3) + if strings.TrimSpace(req.Dir) != "" { + parts = append(parts, "cd "+shellSingleQuote(req.Dir)) + } + + if len(req.Env) > 0 { + keys := make([]string, 0, len(req.Env)) + for key := range req.Env { + if !isValidShellEnvKey(key) { + return "", fmt.Errorf("invalid env key %q", key) + } + keys = append(keys, key) + } + sort.Strings(keys) + + assignments := make([]string, 0, len(keys)) + for _, key := range keys { + assignments = append(assignments, key+"="+shellSingleQuote(req.Env[key])) + } + parts = append(parts, "export "+strings.Join(assignments, " ")) + } + + parts = append(parts, req.Command) + return strings.Join(parts, " && "), nil +} + +func buildExecCommandPowerShell(req ExecRequest) (string, error) { + parts := make([]string, 0, 2+len(req.Env)) + parts = append(parts, "$ErrorActionPreference = 'Stop'") + + if strings.TrimSpace(req.Dir) != "" { + parts = append(parts, "Set-Location -LiteralPath "+powerShellSingleQuote(req.Dir)) + } + + if len(req.Env) > 0 { + keys := make([]string, 0, len(req.Env)) + for key := range req.Env { + if !isValidShellEnvKey(key) { + return "", fmt.Errorf("invalid env key %q", key) + } + keys = append(keys, key) + } + sort.Strings(keys) + + for _, key := range keys { + parts = append(parts, "$env:"+key+" = "+powerShellSingleQuote(req.Env[key])) + } + } + + parts = append(parts, req.Command) + return strings.Join(parts, "; "), nil +} + +func buildExecCommandCMD(req ExecRequest) (string, error) { + keys := make([]string, 0, len(req.Env)) + for key := range req.Env { + if !isValidShellEnvKey(key) { + return "", fmt.Errorf("invalid env key %q", key) + } + keys = append(keys, key) + } + sort.Strings(keys) + + replacements := make(map[string]string, len(keys)+1) + if strings.TrimSpace(req.Dir) != "" { + replacements["CD"] = "!CD!" + } + for _, key := range keys { + replacements[strings.ToUpper(key)] = "!" + key + "!" + } + + command, rewrotePercentVars := rewriteCMDPercentVariables(req.Command, replacements) + needsDelayedExpansion := rewrotePercentVars + if strings.TrimSpace(req.Dir) != "" && cmdContainsBangVariable(command, "CD") { + needsDelayedExpansion = true + } + for _, key := range keys { + if cmdContainsBangVariable(command, key) { + needsDelayedExpansion = true + } + } + + parts := make([]string, 0, 3+len(keys)) + if len(keys) > 0 { + parts = append(parts, "setlocal DisableDelayedExpansion") + for _, key := range keys { + parts = append(parts, "set "+key+"="+cmdEscapeForSetValue(req.Env[key])) + } + } + + if strings.TrimSpace(req.Dir) != "" { + parts = append(parts, "cd /d "+cmdEscapeForBareArgument(req.Dir, true)) + } + + if needsDelayedExpansion { + parts = append(parts, wrapCMDCommand(command)) + } else { + parts = append(parts, command) + } + return strings.Join(parts, " && "), nil +} + +func buildExecCommandRaw(req ExecRequest) (string, error) { + if strings.TrimSpace(req.Dir) != "" { + return "", errors.New("raw exec shell dialect does not support Dir") + } + if len(req.Env) > 0 { + return "", errors.New("raw exec shell dialect does not support Env") + } + return req.Command, nil +} + +func combineCommandOutput(stdout string, stderr string) []byte { + if stdout == "" && stderr == "" { + return nil + } + if stdout == "" { + return []byte(stderr) + } + if stderr == "" { + return []byte(stdout) + } + return []byte(stdout + "\n" + stderr) +} + +func powerShellSingleQuote(s string) string { + return "'" + strings.ReplaceAll(s, "'", "''") + "'" +} + +func wrapCMDCommand(script string) string { + return `cmd.exe /Q /D /V:ON /C ` + cmdEscapeForNestedCommand(script) +} + +func rewriteCMDPercentVariables(command string, replacements map[string]string) (string, bool) { + if len(replacements) == 0 || command == "" { + return command, false + } + + var builder strings.Builder + builder.Grow(len(command)) + + rewrote := false + for i := 0; i < len(command); { + if command[i] != '%' { + builder.WriteByte(command[i]) + i++ + continue + } + + end := strings.IndexByte(command[i+1:], '%') + if end < 0 { + builder.WriteByte(command[i]) + i++ + continue + } + + end += i + 1 + name := command[i+1 : end] + if name == "" { + builder.WriteString(command[i : end+1]) + i = end + 1 + continue + } + + replacement, ok := replacements[strings.ToUpper(name)] + if !ok { + builder.WriteString(command[i : end+1]) + i = end + 1 + continue + } + + builder.WriteString(replacement) + rewrote = true + i = end + 1 + } + + return builder.String(), rewrote +} + +func cmdContainsBangVariable(command string, name string) bool { + if command == "" || name == "" { + return false + } + return strings.Contains(strings.ToUpper(command), "!"+strings.ToUpper(name)+"!") +} + +func cmdEscapeForSetValue(value string) string { + var builder strings.Builder + builder.Grow(len(value)) + for _, char := range value { + switch char { + case '^': + builder.WriteString("^^") + case '&', '|', '<', '>', '(', ')', '"': + builder.WriteByte('^') + builder.WriteRune(char) + case '%': + builder.WriteString("%%") + default: + builder.WriteRune(char) + } + } + return builder.String() +} + +func cmdEscapeForBareArgument(value string, escapeSpace bool) string { + var builder strings.Builder + builder.Grow(len(value)) + for _, char := range value { + switch char { + case '^': + builder.WriteString("^^") + case '&', '|', '<', '>', '(', ')', '"': + builder.WriteByte('^') + builder.WriteRune(char) + case '%': + builder.WriteString("%%") + case ' ': + if escapeSpace { + builder.WriteString("^ ") + } else { + builder.WriteRune(char) + } + default: + builder.WriteRune(char) + } + } + return builder.String() +} + +func cmdEscapeForNestedCommand(command string) string { + return cmdEscapeForBareArgument(command, false) +} + +type captureBuffer struct { + limit int + discard bool + buffer bytes.Buffer + truncated bool +} + +func newCaptureBuffer(limit int, discard bool) *captureBuffer { + return &captureBuffer{ + limit: limit, + discard: discard, + } +} + +func (b *captureBuffer) Write(data []byte) (int, error) { + if len(data) == 0 { + return 0, nil + } + if b == nil { + return len(data), nil + } + if b.discard { + return len(data), nil + } + if b.limit <= 0 { + _, _ = b.buffer.Write(data) + return len(data), nil + } + + remaining := b.limit - b.buffer.Len() + if remaining <= 0 { + b.truncated = true + return len(data), nil + } + + if len(data) > remaining { + _, _ = b.buffer.Write(data[:remaining]) + b.truncated = true + return len(data), nil + } + + _, _ = b.buffer.Write(data) + return len(data), nil +} + +func (b *captureBuffer) Bytes() []byte { + if b == nil { + return nil + } + return b.buffer.Bytes() +} + +func (b *captureBuffer) Truncated() bool { + if b == nil { + return false + } + return b.truncated +} + +func isValidShellEnvKey(key string) bool { + if key == "" { + return false + } + for i, r := range key { + if i == 0 { + if r != '_' && !unicode.IsLetter(r) { + return false + } + continue + } + if r != '_' && !unicode.IsLetter(r) && !unicode.IsDigit(r) { + return false + } + } + return true +} + +func streamExecReader(reader io.Reader, isStderr bool, chunks chan<- ExecStreamChunk, errCh chan<- error, wg *sync.WaitGroup) { + defer wg.Done() + + buf := make([]byte, 4096) + for { + n, err := reader.Read(buf) + if n > 0 { + chunk := make([]byte, n) + copy(chunk, buf[:n]) + chunks <- ExecStreamChunk{ + Data: chunk, + Stderr: isStderr, + } + } + + if err == io.EOF { + return + } + if err != nil { + errCh <- err + return + } + } +} + +type execChunkDispatcher struct { + onChunk func(ExecStreamChunk) + done chan struct{} + + mu sync.Mutex + cond *sync.Cond + queue []ExecStreamChunk + queueBytes int + maxChunks int + maxBytes int + strategy ExecStreamOverflowStrategy + closed bool + stopped bool + failed bool + droppedBytes int + droppedCount int +} + +type execChunkDispatchStats struct { + droppedChunks int + droppedBytes int + err error +} + +func newExecChunkDispatcher(req ExecRequest, onChunk func(ExecStreamChunk)) (*execChunkDispatcher, error) { + if onChunk == nil { + return nil, nil + } + + config, err := normalizeExecChunkDispatchConfig(req) + if err != nil { + return nil, err + } + + dispatcher := &execChunkDispatcher{ + onChunk: onChunk, + done: make(chan struct{}), + maxChunks: config.maxChunks, + maxBytes: config.maxBytes, + strategy: config.strategy, + } + dispatcher.cond = sync.NewCond(&dispatcher.mu) + go dispatcher.run() + return dispatcher, nil +} + +func (d *execChunkDispatcher) Enqueue(chunk ExecStreamChunk) { + if d == nil || len(chunk.Data) == 0 { + return + } + + d.mu.Lock() + defer d.mu.Unlock() + + if d.closed || d.stopped { + d.recordDropLocked(chunk) + return + } + + switch d.strategy { + case ExecStreamOverflowDropOldest: + for len(d.queue) > 0 && d.wouldOverflowLocked(chunk) { + d.recordDropLocked(d.popOldestLocked()) + } + if d.wouldOverflowLocked(chunk) { + d.recordDropLocked(chunk) + return + } + case ExecStreamOverflowDropNewest: + if d.wouldOverflowLocked(chunk) { + d.recordDropLocked(chunk) + return + } + case ExecStreamOverflowFail: + if d.wouldOverflowLocked(chunk) { + d.recordDropLocked(chunk) + d.stopWithOverflowLocked() + return + } + } + + d.queue = append(d.queue, chunk) + d.queueBytes += len(chunk.Data) + d.cond.Signal() +} + +func (d *execChunkDispatcher) Close() { + if d == nil { + return + } + + d.mu.Lock() + if d.closed { + d.mu.Unlock() + return + } + d.closed = true + d.cond.Broadcast() + d.mu.Unlock() +} + +func (d *execChunkDispatcher) run() { + defer close(d.done) + + for { + chunk, ok := d.next() + if !ok { + return + } + d.onChunk(chunk) + } +} + +func (d *execChunkDispatcher) next() (ExecStreamChunk, bool) { + d.mu.Lock() + defer d.mu.Unlock() + + for len(d.queue) == 0 && !d.closed && !d.stopped { + d.cond.Wait() + } + if len(d.queue) == 0 { + return ExecStreamChunk{}, false + } + + chunk := d.popOldestLocked() + return chunk, true +} + +func (d *execChunkDispatcher) Wait() execChunkDispatchStats { + if d == nil { + return execChunkDispatchStats{} + } + <-d.done + + d.mu.Lock() + defer d.mu.Unlock() + return execChunkDispatchStats{ + droppedChunks: d.droppedCount, + droppedBytes: d.droppedBytes, + err: d.dispatchErrorLocked(), + } +} + +func (d *execChunkDispatcher) wouldOverflowLocked(chunk ExecStreamChunk) bool { + if d.maxChunks > 0 && len(d.queue)+1 > d.maxChunks { + return true + } + if d.maxBytes > 0 && d.queueBytes+len(chunk.Data) > d.maxBytes { + return true + } + return false +} + +func (d *execChunkDispatcher) popOldestLocked() ExecStreamChunk { + if len(d.queue) == 0 { + return ExecStreamChunk{} + } + chunk := d.queue[0] + d.queue[0] = ExecStreamChunk{} + d.queue = d.queue[1:] + d.queueBytes -= len(chunk.Data) + if d.queueBytes < 0 { + d.queueBytes = 0 + } + if len(d.queue) == 0 { + d.queue = nil + } + return chunk +} + +func (d *execChunkDispatcher) recordDropLocked(chunk ExecStreamChunk) { + if len(chunk.Data) == 0 { + return + } + d.droppedCount++ + d.droppedBytes += len(chunk.Data) +} + +func (d *execChunkDispatcher) stopWithOverflowLocked() { + if d.stopped { + return + } + for len(d.queue) > 0 { + d.recordDropLocked(d.popOldestLocked()) + } + d.stopped = true + d.failed = true + d.cond.Broadcast() +} + +func (d *execChunkDispatcher) dispatchErrorLocked() error { + if !d.failed { + return nil + } + return &ExecStreamOverflowError{ + DroppedChunks: d.droppedCount, + DroppedBytes: d.droppedBytes, + } +} + +type execChunkDispatchConfig struct { + maxChunks int + maxBytes int + strategy ExecStreamOverflowStrategy +} + +func normalizeExecChunkDispatchConfig(req ExecRequest) (execChunkDispatchConfig, error) { + config := execChunkDispatchConfig{ + maxChunks: req.StreamMaxPendingChunks, + maxBytes: req.StreamMaxPendingBytes, + strategy: req.StreamOverflowStrategy, + } + + if config.maxChunks <= 0 { + config.maxChunks = defaultExecStreamMaxPendingChunks + } + if config.maxBytes <= 0 { + config.maxBytes = defaultExecStreamMaxPendingBytes + } + if config.strategy == "" { + config.strategy = ExecStreamOverflowDropOldest + } + + switch config.strategy { + case ExecStreamOverflowDropOldest, ExecStreamOverflowDropNewest, ExecStreamOverflowFail: + return config, nil + default: + return execChunkDispatchConfig{}, fmt.Errorf("invalid exec stream overflow strategy %q", req.StreamOverflowStrategy) + } +} + +func firstExecError(errCh <-chan error) error { + for err := range errCh { + if err != nil { + return err + } + } + return nil +} + +func (s *StarSSH) ShellOne(cmd string) (string, error) { + result, err := s.Exec(context.Background(), ExecRequest{ + Command: cmd, + }) + if err != nil { + return "", err + } + + combined := strings.TrimSpace(result.CombinedString()) + if cmdErr := result.CommandError(); cmdErr != nil { + return combined, cmdErr + } + return combined, nil +} + +func (s *StarSSH) ShellOneShowScreen(cmd string) (string, error) { + return s.streamCommand(cmd, func(chunk string) { + fmt.Print(chunk) + }) +} + +func (s *StarSSH) ShellOneShowScreenResult(cmd string) (*ExecResult, error) { + return s.streamCommandResult(cmd, func(chunk string) { + fmt.Print(chunk) + }) +} + +func (s *StarSSH) ShellOneToFunc(cmd string, callback func(string)) (string, error) { + return s.streamCommand(cmd, callback) +} + +func (s *StarSSH) ShellOneToFuncResult(cmd string, callback func(string)) (*ExecResult, error) { + return s.streamCommandResult(cmd, callback) +} + +func (s *StarSSH) streamCommand(cmd string, onChunk func(string)) (string, error) { + result, err := s.streamCommandResult(cmd, onChunk) + stdoutText := strings.TrimSpace(resultStdoutString(result)) + if err != nil { + return stdoutText, err + } + return stdoutText, streamCommandLegacyError(result) +} + +func (s *StarSSH) streamCommandResult(cmd string, onChunk func(string)) (*ExecResult, error) { + result, err := s.ExecStream(context.Background(), ExecRequest{ + Command: cmd, + }, func(chunk ExecStreamChunk) { + if onChunk != nil { + onChunk(string(chunk.Data)) + } + }) + return result, err +} + +func streamCommandLegacyError(result *ExecResult) error { + if result == nil { + return nil + } + return errors.Join(result.CommandError(), result.StreamOverflowError()) +} + +func resultStdoutString(result *ExecResult) string { + if result == nil { + return "" + } + return result.StdoutString() +} + +func (s *StarSSH) Exists(filepath string) bool { + return s.remotePathProbe(legacyPathProbeExists, filepath) +} + +func (s *StarSSH) IsFile(filepath string) bool { + return s.remotePathProbe(legacyPathProbeFile, filepath) +} + +func (s *StarSSH) IsFolder(filepath string) bool { + return s.remotePathProbe(legacyPathProbeDirectory, filepath) +} + +func (s *StarSSH) GetUid() string { + return s.remoteIdentityProbe(legacyIdentityProbeUID) +} + +func (s *StarSSH) GetGid() string { + return s.remoteIdentityProbe(legacyIdentityProbeGID) +} + +func (s *StarSSH) GetUser() string { + return s.remoteIdentityProbe(legacyIdentityProbeUser) +} + +func (s *StarSSH) GetGroup() string { + return s.remoteIdentityProbe(legacyIdentityProbeGroup) +} + +type legacyPathProbeKind string + +const ( + legacyPathProbeExists legacyPathProbeKind = "exists" + legacyPathProbeFile legacyPathProbeKind = "file" + legacyPathProbeDirectory legacyPathProbeKind = "directory" +) + +type legacyIdentityProbeKind string + +const ( + legacyIdentityProbeUID legacyIdentityProbeKind = "uid" + legacyIdentityProbeGID legacyIdentityProbeKind = "gid" + legacyIdentityProbeUser legacyIdentityProbeKind = "user" + legacyIdentityProbeGroup legacyIdentityProbeKind = "group" +) + +func (s *StarSSH) remotePathProbe(kind legacyPathProbeKind, filepath string) bool { + result, ok := s.tryLegacyProbeRequests(buildLegacyPathProbeRequests(kind, filepath)) + return ok && strings.TrimSpace(result) == "1" +} + +func (s *StarSSH) remoteIdentityProbe(kind legacyIdentityProbeKind) string { + result, ok := s.tryLegacyProbeRequests(buildLegacyIdentityProbeRequests(kind)) + if !ok { + return "" + } + return strings.TrimSpace(result) +} + +func (s *StarSSH) tryLegacyProbeRequests(requests []ExecRequest) (string, bool) { + if s == nil { + return "", false + } + + for _, req := range requests { + result, err := execRequestRunner(s, context.Background(), req) + if err != nil || result == nil || result.CommandError() != nil { + continue + } + return result.StdoutString(), true + } + return "", false +} + +func buildLegacyPathProbeRequests(kind legacyPathProbeKind, filepath string) []ExecRequest { + requests := []ExecRequest{ + buildLegacyRawExecRequest(wrapPOSIXRawCommand(buildLegacyPOSIXPathProbeScript(kind, filepath))), + } + for _, executable := range []string{"powershell.exe", "pwsh.exe", "pwsh"} { + requests = append(requests, buildLegacyRawExecRequest( + wrapPowerShellRawCommand(executable, buildLegacyPowerShellPathProbeScript(kind, filepath)), + )) + } + requests = append(requests, buildLegacyRawExecRequest(buildLegacyCMDPathProbeCommand(kind, filepath))) + return requests +} + +func buildLegacyIdentityProbeRequests(kind legacyIdentityProbeKind) []ExecRequest { + requests := []ExecRequest{ + buildLegacyRawExecRequest(wrapPOSIXRawCommand(buildLegacyPOSIXIdentityProbeScript(kind))), + } + for _, executable := range []string{"powershell.exe", "pwsh.exe", "pwsh"} { + requests = append(requests, buildLegacyRawExecRequest( + wrapPowerShellRawCommand(executable, buildLegacyPowerShellIdentityProbeScript(kind)), + )) + } + if kind == legacyIdentityProbeUser { + requests = append(requests, buildLegacyRawExecRequest("cmd.exe /Q /D /C whoami")) + } + return requests +} + +func buildLegacyRawExecRequest(command string) ExecRequest { + return ExecRequest{ + Command: command, + ShellDialect: ExecShellDialectRaw, + } +} + +func wrapPOSIXRawCommand(script string) string { + return "sh -lc " + shellSingleQuote(script) +} + +func wrapPowerShellRawCommand(executable string, script string) string { + return executable + " -NoLogo -NoProfile -NonInteractive -Command " + powerShellSingleQuote(script) +} + +func buildLegacyPOSIXPathProbeScript(kind legacyPathProbeKind, filepath string) string { + flag := "-e" + switch kind { + case legacyPathProbeFile: + flag = "-f" + case legacyPathProbeDirectory: + flag = "-d" + } + return fmt.Sprintf("if [ %s -- %s ]; then printf '1\\n'; else printf '0\\n'; fi", flag, shellSingleQuote(filepath)) +} + +func buildLegacyPowerShellPathProbeScript(kind legacyPathProbeKind, filepath string) string { + condition := "Test-Path -LiteralPath " + powerShellSingleQuote(filepath) + switch kind { + case legacyPathProbeFile: + condition += " -PathType Leaf" + case legacyPathProbeDirectory: + condition += " -PathType Container" + } + return "if (" + condition + ") { Write-Output '1' } else { Write-Output '0' }" +} + +func buildLegacyCMDPathProbeCommand(kind legacyPathProbeKind, filepath string) string { + parts := []string{ + "setlocal DisableDelayedExpansion", + `set "STARSSH_PATH=` + cmdEscapeForSetValue(filepath) + `"`, + } + + switch kind { + case legacyPathProbeExists: + parts = append(parts, wrapCMDCommand(`if exist "!STARSSH_PATH!" (echo 1) else echo 0`)) + case legacyPathProbeFile: + parts = append(parts, wrapCMDCommand(`if exist "!STARSSH_PATH!" (if exist "!STARSSH_PATH!\NUL" (echo 0) else echo 1) else echo 0`)) + case legacyPathProbeDirectory: + parts = append(parts, wrapCMDCommand(`if exist "!STARSSH_PATH!\NUL" (echo 1) else echo 0`)) + } + return strings.Join(parts, " && ") +} + +func buildLegacyPOSIXIdentityProbeScript(kind legacyIdentityProbeKind) string { + switch kind { + case legacyIdentityProbeUID: + return "id -u" + case legacyIdentityProbeGID: + return "id -g" + case legacyIdentityProbeUser: + return "id -un" + case legacyIdentityProbeGroup: + return "id -gn" + default: + return "" + } +} + +func buildLegacyPowerShellIdentityProbeScript(kind legacyIdentityProbeKind) string { + switch kind { + case legacyIdentityProbeUID: + return "[System.Security.Principal.WindowsIdentity]::GetCurrent().User.Value" + case legacyIdentityProbeGID: + return "$id = [System.Security.Principal.WindowsIdentity]::GetCurrent(); $group = $id.Groups | Select-Object -First 1; if ($group) { $group.Value }" + case legacyIdentityProbeUser: + return "$env:USERNAME" + case legacyIdentityProbeGroup: + return "$id = [System.Security.Principal.WindowsIdentity]::GetCurrent(); $group = $id.Groups | Select-Object -First 1; if ($group) { try { $group.Translate([System.Security.Principal.NTAccount]).Value } catch { $group.Value } }" + default: + return "" + } +} diff --git a/exec_legacy_test.go b/exec_legacy_test.go new file mode 100644 index 0000000..118466a --- /dev/null +++ b/exec_legacy_test.go @@ -0,0 +1,69 @@ +package starssh + +import ( + "context" + "strings" + "testing" +) + +func TestExistsFallsBackFromPOSIXToPowerShell(t *testing.T) { + oldExecRequestRunner := execRequestRunner + t.Cleanup(func() { + execRequestRunner = oldExecRequestRunner + }) + + var calls []ExecRequest + execRequestRunner = func(s *StarSSH, ctx context.Context, req ExecRequest) (*ExecResult, error) { + calls = append(calls, req) + switch len(calls) { + case 1: + return &ExecResult{ExitCode: 127, Stderr: []byte("sh not found")}, nil + case 2: + return &ExecResult{Stdout: []byte("1\n")}, nil + default: + t.Fatalf("unexpected extra probe request: %+v", req) + return nil, nil + } + } + + star := &StarSSH{} + if !star.Exists(`C:\Windows\System32`) { + t.Fatal("expected helper to succeed after PowerShell fallback") + } + if len(calls) != 2 { + t.Fatalf("expected two probe attempts, got %d", len(calls)) + } + if calls[0].ShellDialect != ExecShellDialectRaw || !strings.HasPrefix(calls[0].Command, "sh -lc ") { + t.Fatalf("unexpected first probe request: %+v", calls[0]) + } + if calls[1].ShellDialect != ExecShellDialectRaw || !strings.HasPrefix(strings.ToLower(calls[1].Command), "powershell.exe ") { + t.Fatalf("unexpected second probe request: %+v", calls[1]) + } +} + +func TestGetUserFallsBackToCMDWhenPowerShellVariantsFail(t *testing.T) { + oldExecRequestRunner := execRequestRunner + t.Cleanup(func() { + execRequestRunner = oldExecRequestRunner + }) + + var calls []ExecRequest + execRequestRunner = func(s *StarSSH, ctx context.Context, req ExecRequest) (*ExecResult, error) { + calls = append(calls, req) + if len(calls) < 5 { + return &ExecResult{ExitCode: 127, Stderr: []byte("command not found")}, nil + } + return &ExecResult{Stdout: []byte("HOST\\tester\r\n")}, nil + } + + star := &StarSSH{} + if got := star.GetUser(); got != `HOST\tester` { + t.Fatalf("unexpected user after fallback: %q", got) + } + if len(calls) != 5 { + t.Fatalf("expected five probe attempts, got %d", len(calls)) + } + if got := strings.ToLower(calls[4].Command); got != "cmd.exe /q /d /c whoami" { + t.Fatalf("unexpected final fallback command: %q", calls[4].Command) + } +} diff --git a/forward.go b/forward.go new file mode 100644 index 0000000..57f9e10 --- /dev/null +++ b/forward.go @@ -0,0 +1,694 @@ +package starssh + +import ( + "context" + "errors" + "io" + "net" + "strconv" + "strings" + "sync" + + "golang.org/x/crypto/ssh" +) + +type ForwardRequest struct { + ListenAddr string + TargetAddr string + DialContext DialContextFunc +} + +type DynamicForwardRequest struct { + ListenAddr string +} + +type PortForwarder struct { + listener net.Listener + ctx context.Context + cancel context.CancelFunc + acceptDone chan struct{} + + connWG sync.WaitGroup + closeOnce sync.Once + + connMu sync.Mutex + conns map[net.Conn]struct{} + + errMu sync.Mutex + err error + + cleanupOnce sync.Once + cleanupFns []func() error +} + +var dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) { + return client.Dial(network, address) +} + +var newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) { + if ctx == nil { + ctx = context.Background() + } + return LoginContext(ctx, input) +} + +func (s *StarSSH) DialTCP(network string, address string) (net.Conn, error) { + return s.DialTCPContext(context.Background(), network, address) +} + +func (s *StarSSH) DialTCPContext(ctx context.Context, network string, address string) (net.Conn, error) { + return s.dialTCPContext(ctx, network, address, nil) +} + +func (s *StarSSH) DialTCPContextCloseOnCancel(ctx context.Context, network string, address string) (net.Conn, error) { + return s.dialTCPContext(ctx, network, address, s.Close) +} + +func (s *StarSSH) dialTCPContext(ctx context.Context, network string, address string, onCancel func() error) (net.Conn, error) { + if ctx == nil { + ctx = context.Background() + } + if strings.TrimSpace(network) == "" { + network = "tcp" + } + if strings.TrimSpace(address) == "" { + return nil, errors.New("forward address is empty") + } + + type dialResult struct { + conn net.Conn + err error + } + + client, err := s.requireSSHClient() + if err != nil { + return nil, err + } + + runCancel := func() {} + if onCancel != nil { + var cancelOnce sync.Once + runCancel = func() { + cancelOnce.Do(func() { + _ = onCancel() + }) + } + + cancelDone := make(chan struct{}) + defer close(cancelDone) + go func() { + select { + case <-ctx.Done(): + runCancel() + case <-cancelDone: + } + }() + } + + dialFunc := dialSSHClient + resultCh := make(chan dialResult, 1) + go func() { + conn, err := dialFunc(ctx, client, network, address) + if ctx.Err() != nil && conn != nil { + _ = conn.Close() + conn = nil + } + + select { + case resultCh <- dialResult{conn: conn, err: err}: + default: + if conn != nil { + _ = conn.Close() + } + } + }() + + select { + case result := <-resultCh: + return result.conn, result.err + case <-ctx.Done(): + runCancel() + return nil, ctx.Err() + } +} + +func (s *StarSSH) StartLocalForward(req ForwardRequest) (*PortForwarder, error) { + if _, err := s.requireSSHClient(); err != nil { + return nil, err + } + if strings.TrimSpace(req.ListenAddr) == "" { + return nil, errors.New("local forward listen address is empty") + } + if strings.TrimSpace(req.TargetAddr) == "" { + return nil, errors.New("local forward target address is empty") + } + + listener, err := net.Listen("tcp", req.ListenAddr) + if err != nil { + return nil, err + } + + forwarder := newPortForwarder(listener) + forwarder.serve(func(ctx context.Context) (net.Conn, error) { + return s.DialTCPContext(ctx, "tcp", req.TargetAddr) + }) + return forwarder, nil +} + +func (s *StarSSH) StartLocalForwardDetached(req ForwardRequest) (*PortForwarder, error) { + if _, err := s.requireSSHClient(); err != nil { + return nil, err + } + if strings.TrimSpace(req.ListenAddr) == "" { + return nil, errors.New("local forward listen address is empty") + } + if strings.TrimSpace(req.TargetAddr) == "" { + return nil, errors.New("local forward target address is empty") + } + + listener, err := net.Listen("tcp", req.ListenAddr) + if err != nil { + return nil, err + } + + forwardClient, err := s.newForwardDialClient(context.Background()) + if err != nil { + _ = listener.Close() + return nil, err + } + + forwarder := newPortForwarder(listener) + forwarder.addCleanup(func() error { + return normalizeAlreadyClosedError(forwardClient.Close()) + }) + forwarder.serve(func(ctx context.Context) (net.Conn, error) { + return forwardClient.DialTCPContext(ctx, "tcp", req.TargetAddr) + }) + return forwarder, nil +} + +func (s *StarSSH) StartRemoteForward(req ForwardRequest) (*PortForwarder, error) { + client, err := s.requireSSHClient() + if err != nil { + return nil, err + } + if strings.TrimSpace(req.ListenAddr) == "" { + return nil, errors.New("remote forward listen address is empty") + } + if strings.TrimSpace(req.TargetAddr) == "" { + return nil, errors.New("remote forward target address is empty") + } + + listener, err := client.Listen("tcp", req.ListenAddr) + if err != nil { + return nil, err + } + + dialContext := req.DialContext + if dialContext == nil { + dialer := &net.Dialer{ + Timeout: defaultLoginTimeout, + } + dialContext = dialer.DialContext + } + + forwarder := newPortForwarder(listener) + forwarder.serve(func(ctx context.Context) (net.Conn, error) { + return dialContext(ctx, "tcp", req.TargetAddr) + }) + return forwarder, nil +} + +func (s *StarSSH) StartDynamicForward(req DynamicForwardRequest) (*PortForwarder, error) { + if _, err := s.requireSSHClient(); err != nil { + return nil, err + } + if strings.TrimSpace(req.ListenAddr) == "" { + return nil, errors.New("dynamic forward listen address is empty") + } + + listener, err := net.Listen("tcp", req.ListenAddr) + if err != nil { + return nil, err + } + + forwarder := newPortForwarder(listener) + forwarder.serveDynamic(func(ctx context.Context, targetAddr string) (net.Conn, error) { + return s.DialTCPContext(ctx, "tcp", targetAddr) + }) + return forwarder, nil +} + +func (s *StarSSH) StartDynamicForwardDetached(req DynamicForwardRequest) (*PortForwarder, error) { + if _, err := s.requireSSHClient(); err != nil { + return nil, err + } + if strings.TrimSpace(req.ListenAddr) == "" { + return nil, errors.New("dynamic forward listen address is empty") + } + + listener, err := net.Listen("tcp", req.ListenAddr) + if err != nil { + return nil, err + } + + forwardClient, err := s.newForwardDialClient(context.Background()) + if err != nil { + _ = listener.Close() + return nil, err + } + + forwarder := newPortForwarder(listener) + forwarder.addCleanup(func() error { + return normalizeAlreadyClosedError(forwardClient.Close()) + }) + forwarder.serveDynamic(func(ctx context.Context, targetAddr string) (net.Conn, error) { + return forwardClient.DialTCPContext(ctx, "tcp", targetAddr) + }) + return forwarder, nil +} + +func (f *PortForwarder) Addr() net.Addr { + if f == nil || f.listener == nil { + return nil + } + return f.listener.Addr() +} + +func (f *PortForwarder) Wait() error { + if f == nil { + return nil + } + <-f.acceptDone + f.connWG.Wait() + f.runCleanup() + return f.Err() +} + +func (f *PortForwarder) Err() error { + if f == nil { + return nil + } + f.errMu.Lock() + defer f.errMu.Unlock() + return f.err +} + +func (f *PortForwarder) Close() error { + if f == nil { + return nil + } + + var closeErr error + f.closeOnce.Do(func() { + if f.cancel != nil { + f.cancel() + } + if f.listener != nil { + closeErr = normalizeAlreadyClosedError(f.listener.Close()) + } + f.closeActiveConnections() + }) + + <-f.acceptDone + f.connWG.Wait() + f.runCleanup() + if closeErr != nil { + return closeErr + } + return f.Err() +} + +func newPortForwarder(listener net.Listener) *PortForwarder { + ctx, cancel := context.WithCancel(context.Background()) + return &PortForwarder{ + listener: listener, + ctx: ctx, + cancel: cancel, + acceptDone: make(chan struct{}), + conns: make(map[net.Conn]struct{}), + } +} + +func (s *StarSSH) newForwardDialClient(ctx context.Context) (*StarSSH, error) { + if s == nil { + return nil, errors.New("ssh client is nil") + } + return newDetachedForwardClient(ctx, s.LoginInfo) +} + +func (f *PortForwarder) addCleanup(fn func() error) { + if f == nil || fn == nil { + return + } + f.cleanupFns = append(f.cleanupFns, fn) +} + +func (f *PortForwarder) runCleanup() { + if f == nil { + return + } + + f.cleanupOnce.Do(func() { + for _, fn := range f.cleanupFns { + f.setError(normalizeAlreadyClosedError(fn())) + } + }) +} + +func (f *PortForwarder) serve(targetDial func(context.Context) (net.Conn, error)) { + go func() { + defer close(f.acceptDone) + + for { + conn, err := f.listener.Accept() + if err != nil { + if isClosedListenerError(err) { + return + } + f.setError(err) + return + } + + f.trackConn(conn) + f.connWG.Add(1) + go func(src net.Conn) { + defer f.connWG.Done() + defer f.untrackConn(src) + + dst, err := targetDial(f.ctx) + if err != nil { + f.setError(err) + _ = src.Close() + return + } + f.trackConn(dst) + defer f.untrackConn(dst) + + f.setError(pipeForwardConnections(src, dst)) + }(conn) + } + }() +} + +func (f *PortForwarder) serveDynamic(targetDial func(context.Context, string) (net.Conn, error)) { + go func() { + defer close(f.acceptDone) + + for { + conn, err := f.listener.Accept() + if err != nil { + if isClosedListenerError(err) { + return + } + f.setError(err) + return + } + + f.trackConn(conn) + f.connWG.Add(1) + go func(src net.Conn) { + defer f.connWG.Done() + defer f.untrackConn(src) + if err := handleDynamicForwardConn(f.ctx, src, targetDial, f.trackConn, f.untrackConn); err != nil { + f.setError(err) + } + }(conn) + } + }() +} + +func (f *PortForwarder) setError(err error) { + if f == nil || err == nil || f.shouldIgnoreError(err) { + return + } + + f.errMu.Lock() + defer f.errMu.Unlock() + if f.err == nil { + f.err = err + } +} + +func pipeForwardConnections(left net.Conn, right net.Conn) error { + if left == nil || right == nil { + if left != nil { + _ = left.Close() + } + if right != nil { + _ = right.Close() + } + return errors.New("forward connection endpoint is nil") + } + + defer left.Close() + defer right.Close() + + var copyWG sync.WaitGroup + errCh := make(chan error, 2) + copyWG.Add(2) + go func() { + defer copyWG.Done() + _, err := io.Copy(right, left) + errCh <- normalizeAlreadyClosedError(err) + closeWrite(right) + }() + go func() { + defer copyWG.Done() + _, err := io.Copy(left, right) + errCh <- normalizeAlreadyClosedError(err) + closeWrite(left) + }() + copyWG.Wait() + close(errCh) + + for err := range errCh { + if err != nil { + return err + } + } + return nil +} + +func closeWrite(conn net.Conn) { + if conn == nil { + return + } + type closeWriter interface { + CloseWrite() error + } + if writer, ok := conn.(closeWriter); ok { + _ = writer.CloseWrite() + } +} + +func handleDynamicForwardConn( + ctx context.Context, + src net.Conn, + targetDial func(context.Context, string) (net.Conn, error), + trackConn func(net.Conn), + untrackConn func(net.Conn), +) error { + if src == nil { + return errors.New("dynamic forward source connection is nil") + } + defer src.Close() + + if err := negotiateSOCKS5NoAuth(src); err != nil { + return err + } + + targetAddr, replyCode, err := readSOCKS5ConnectTarget(src) + if err != nil { + _ = writeSOCKS5ServerReply(src, replyCode, nil) + return err + } + + dst, err := targetDial(ctx, targetAddr) + if err != nil { + _ = writeSOCKS5ServerReply(src, 0x01, nil) + return err + } + if trackConn != nil { + trackConn(dst) + defer untrackConn(dst) + } + + if err := writeSOCKS5ServerReply(src, 0x00, dst.LocalAddr()); err != nil { + _ = dst.Close() + return err + } + + return pipeForwardConnections(src, dst) +} + +func (f *PortForwarder) trackConn(conn net.Conn) { + if f == nil || conn == nil { + return + } + f.connMu.Lock() + defer f.connMu.Unlock() + f.conns[conn] = struct{}{} +} + +func (f *PortForwarder) untrackConn(conn net.Conn) { + if f == nil || conn == nil { + return + } + f.connMu.Lock() + defer f.connMu.Unlock() + delete(f.conns, conn) +} + +func (f *PortForwarder) closeActiveConnections() { + if f == nil { + return + } + + f.connMu.Lock() + conns := make([]net.Conn, 0, len(f.conns)) + for conn := range f.conns { + conns = append(conns, conn) + } + f.connMu.Unlock() + + for _, conn := range conns { + _ = conn.Close() + } +} + +func (f *PortForwarder) shouldIgnoreError(err error) bool { + if err == nil { + return true + } + if normalizeAlreadyClosedError(err) == nil { + return true + } + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return true + } + return false +} + +func negotiateSOCKS5NoAuth(conn net.Conn) error { + header := make([]byte, 2) + if _, err := io.ReadFull(conn, header); err != nil { + return err + } + if header[0] != 0x05 { + return errors.New("invalid socks5 version") + } + + methodCount := int(header[1]) + methods := make([]byte, methodCount) + if _, err := io.ReadFull(conn, methods); err != nil { + return err + } + + method := byte(0xFF) + for _, candidate := range methods { + if candidate == 0x00 { + method = 0x00 + break + } + } + + if _, err := conn.Write([]byte{0x05, method}); err != nil { + return err + } + if method == 0xFF { + return errors.New("socks5 client does not support no-auth method") + } + return nil +} + +func readSOCKS5ConnectTarget(conn net.Conn) (string, byte, error) { + header := make([]byte, 4) + if _, err := io.ReadFull(conn, header); err != nil { + return "", 0x01, err + } + if header[0] != 0x05 { + return "", 0x01, errors.New("invalid socks5 request version") + } + if header[1] != 0x01 { + return "", 0x07, errors.New("unsupported socks5 command") + } + + host, err := readSOCKS5RequestHost(conn, header[3]) + if err != nil { + return "", 0x08, err + } + + portBytes := make([]byte, 2) + if _, err := io.ReadFull(conn, portBytes); err != nil { + return "", 0x01, err + } + port := int(portBytes[0])<<8 | int(portBytes[1]) + return net.JoinHostPort(host, strconv.Itoa(port)), 0x00, nil +} + +func readSOCKS5RequestHost(conn net.Conn, addressType byte) (string, error) { + switch addressType { + case 0x01: + buffer := make([]byte, 4) + if _, err := io.ReadFull(conn, buffer); err != nil { + return "", err + } + return net.IP(buffer).String(), nil + case 0x03: + size := make([]byte, 1) + if _, err := io.ReadFull(conn, size); err != nil { + return "", err + } + buffer := make([]byte, int(size[0])) + if _, err := io.ReadFull(conn, buffer); err != nil { + return "", err + } + return string(buffer), nil + case 0x04: + buffer := make([]byte, 16) + if _, err := io.ReadFull(conn, buffer); err != nil { + return "", err + } + return net.IP(buffer).String(), nil + default: + return "", errors.New("unsupported socks5 address type") + } +} + +func writeSOCKS5ServerReply(conn net.Conn, replyCode byte, addr net.Addr) error { + reply := []byte{0x05, replyCode, 0x00} + + if tcpAddr, ok := addr.(*net.TCPAddr); ok && tcpAddr != nil { + if ip4 := tcpAddr.IP.To4(); ip4 != nil { + reply = append(reply, 0x01) + reply = append(reply, ip4...) + } else if ip16 := tcpAddr.IP.To16(); ip16 != nil { + reply = append(reply, 0x04) + reply = append(reply, ip16...) + } + if len(reply) > 3 { + port := tcpAddr.Port + reply = append(reply, byte(port>>8), byte(port)) + _, err := conn.Write(reply) + return err + } + } + + reply = append(reply, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00) + _, err := conn.Write(reply) + return err +} + +func isClosedListenerError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, io.EOF) { + return true + } + if errors.Is(err, net.ErrClosed) { + return true + } + return strings.Contains(err.Error(), "use of closed network connection") +} diff --git a/forward_test.go b/forward_test.go new file mode 100644 index 0000000..ec38d5d --- /dev/null +++ b/forward_test.go @@ -0,0 +1,164 @@ +package starssh + +import ( + "context" + "io" + "net" + "sync/atomic" + "testing" + "time" + + "golang.org/x/crypto/ssh" +) + +func TestStartLocalForwardUsesExistingConnectionByDefault(t *testing.T) { + oldDialSSHClient := dialSSHClient + oldNewDetachedForwardClient := newDetachedForwardClient + oldCloseSSHClient := closeSSHClient + t.Cleanup(func() { + dialSSHClient = oldDialSSHClient + newDetachedForwardClient = oldNewDetachedForwardClient + closeSSHClient = oldCloseSSHClient + }) + + baseClient := &ssh.Client{} + star := &StarSSH{} + star.setTransport(baseClient, nil) + + var detachedCalls atomic.Int32 + newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) { + detachedCalls.Add(1) + return nil, nil + } + + dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) { + if client != baseClient { + t.Errorf("expected existing ssh client, got %p want %p", client, baseClient) + } + serverConn, clientConn := net.Pipe() + go echoForwardPipe(serverConn) + return clientConn, nil + } + + closeSSHClient = func(client sshClientRequester) error { + t.Fatal("default local forward should not close the main ssh client") + return nil + } + + forwarder, err := star.StartLocalForward(ForwardRequest{ + ListenAddr: "127.0.0.1:0", + TargetAddr: "example.internal:22", + }) + if err != nil { + t.Fatalf("start local forward: %v", err) + } + defer forwarder.Close() + + reply := exerciseForwarder(t, forwarder.Addr().String(), []byte("ping")) + if string(reply) != "ping" { + t.Fatalf("unexpected forwarded reply: %q", string(reply)) + } + if detachedCalls.Load() != 0 { + t.Fatalf("default local forward should not create detached ssh client, calls=%d", detachedCalls.Load()) + } +} + +func TestStartLocalForwardDetachedUsesSeparateConnection(t *testing.T) { + oldDialSSHClient := dialSSHClient + oldNewDetachedForwardClient := newDetachedForwardClient + oldCloseSSHClient := closeSSHClient + t.Cleanup(func() { + dialSSHClient = oldDialSSHClient + newDetachedForwardClient = oldNewDetachedForwardClient + closeSSHClient = oldCloseSSHClient + }) + + baseClient := &ssh.Client{} + detachedClient := &ssh.Client{} + star := &StarSSH{LoginInfo: LoginInput{User: "tester", Addr: "127.0.0.1"}} + star.setTransport(baseClient, nil) + + forwardClient := &StarSSH{} + forwardClient.setTransport(detachedClient, nil) + + var detachedCalls atomic.Int32 + newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) { + detachedCalls.Add(1) + return forwardClient, nil + } + + dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) { + if client != detachedClient { + t.Errorf("expected detached ssh client, got %p want %p", client, detachedClient) + } + serverConn, clientConn := net.Pipe() + go echoForwardPipe(serverConn) + return clientConn, nil + } + + var closeCalls atomic.Int32 + closeSSHClient = func(client sshClientRequester) error { + closeCalls.Add(1) + return nil + } + + forwarder, err := star.StartLocalForwardDetached(ForwardRequest{ + ListenAddr: "127.0.0.1:0", + TargetAddr: "example.internal:22", + }) + if err != nil { + t.Fatalf("start detached local forward: %v", err) + } + + reply := exerciseForwarder(t, forwarder.Addr().String(), []byte("pong")) + if string(reply) != "pong" { + t.Fatalf("unexpected detached forwarded reply: %q", string(reply)) + } + + if err := forwarder.Close(); err != nil { + t.Fatalf("close detached local forward: %v", err) + } + if detachedCalls.Load() != 1 { + t.Fatalf("expected one detached ssh login, got %d", detachedCalls.Load()) + } + if closeCalls.Load() != 1 { + t.Fatalf("expected detached ssh client cleanup once, got %d", closeCalls.Load()) + } + if got := star.snapshotSSHClient(); got != baseClient { + t.Fatal("detached local forward should not detach the main ssh client") + } + if got := forwardClient.snapshotSSHClient(); got != nil { + t.Fatal("detached local forward should close its detached ssh client") + } +} + +func echoForwardPipe(conn net.Conn) { + defer conn.Close() + buf := make([]byte, 4096) + n, err := conn.Read(buf) + if err != nil { + return + } + _, _ = conn.Write(buf[:n]) +} + +func exerciseForwarder(t *testing.T, addr string, payload []byte) []byte { + t.Helper() + + conn, err := net.DialTimeout("tcp", addr, time.Second) + if err != nil { + t.Fatalf("dial forward listener: %v", err) + } + defer conn.Close() + _ = conn.SetDeadline(time.Now().Add(2 * time.Second)) + + if _, err := conn.Write(payload); err != nil { + t.Fatalf("write forwarded payload: %v", err) + } + + reply := make([]byte, len(payload)) + if _, err := io.ReadFull(conn, reply); err != nil { + t.Fatalf("read forwarded reply: %v", err) + } + return reply +} diff --git a/go.mod b/go.mod index 8a5ae98..7bd62b4 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,17 @@ module b612.me/starssh -go 1.16 +go 1.20 require ( - github.com/pkg/sftp v1.13.4 - golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000 + github.com/Microsoft/go-winio v0.6.1 + github.com/pkg/sftp v1.13.9 + golang.org/x/crypto v0.33.0 + golang.org/x/sys v0.30.0 +) + +require ( + github.com/kr/fs v0.1.0 // indirect + golang.org/x/mod v0.17.0 // indirect + golang.org/x/sync v0.10.0 // indirect + golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect ) diff --git a/go.sum b/go.sum index 6078fd9..be0baf5 100644 --- a/go.sum +++ b/go.sum @@ -1,29 +1,92 @@ -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= +github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= -github.com/pkg/sftp v1.13.4 h1:Lb0RYJCmgUcBgZosfoi9Y9sbl6+LJgOIgk/2Y4YjMFg= -github.com/pkg/sftp v1.13.4/go.mod h1:LzqnAvaD5TWeNBsZpfKxSYn1MbjWwOsCIAFFJbpIsK8= +github.com/pkg/sftp v1.13.9 h1:4NGkvGudBL7GteO3m6qnaQ4pC0Kvf0onSVc9gR3EWBw= +github.com/pkg/sftp v1.13.9/go.mod h1:OBN7bVXdstkFFN/gdnHPUb5TE8eb8G1Rp9wCItqjkkA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000 h1:SL+8VVnkqyshUSz5iNnXtrBQzvFF2SkROm6t5RczFAE= -golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= +golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= +golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= +golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= +golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= +golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= +golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/hostkey.go b/hostkey.go new file mode 100644 index 0000000..77dc21e --- /dev/null +++ b/hostkey.go @@ -0,0 +1,290 @@ +package starssh + +import ( + "errors" + "fmt" + "net" + "os" + "path/filepath" + "strings" + "sync" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/knownhosts" +) + +var ErrKnownHostsFileRequired = errors.New("known_hosts file is required") +var ErrHostFingerprintRequired = errors.New("host key fingerprint is required") + +type HostKeyFingerprintMismatchError struct { + Expected []string + ActualSHA256 string + ActualLegacyMD5 string +} + +type AcceptNewHostKeyOptions struct { + KnownHostsFile string + HashHosts bool + IncludeRemoteAddress bool + FileMode os.FileMode +} + +func (e *HostKeyFingerprintMismatchError) Error() string { + if e == nil { + return "" + } + return fmt.Sprintf("host key fingerprint mismatch: want one of %s, got %s (%s)", strings.Join(e.Expected, ", "), e.ActualSHA256, e.ActualLegacyMD5) +} + +func KnownHostsHostKeyCallback(files ...string) (func(string, net.Addr, ssh.PublicKey) error, error) { + trimmed := make([]string, 0, len(files)) + for _, file := range files { + file = strings.TrimSpace(file) + if file == "" { + continue + } + trimmed = append(trimmed, file) + } + if len(trimmed) == 0 { + return nil, ErrKnownHostsFileRequired + } + return knownhosts.New(trimmed...) +} + +func AcceptNewHostKeyCallback(file string) (func(string, net.Addr, ssh.PublicKey) error, error) { + return AcceptNewHostKeyCallbackWithOptions(AcceptNewHostKeyOptions{ + KnownHostsFile: file, + IncludeRemoteAddress: true, + }) +} + +func AcceptNewHostKeyCallbackWithOptions(options AcceptNewHostKeyOptions) (func(string, net.Addr, ssh.PublicKey) error, error) { + options = normalizeAcceptNewHostKeyOptions(options) + if options.KnownHostsFile == "" { + return nil, ErrKnownHostsFileRequired + } + + state := &acceptNewHostKeyState{ + file: options.KnownHostsFile, + hashHosts: options.HashHosts, + includeRemoteAddress: options.IncludeRemoteAddress, + fileMode: options.FileMode, + } + if err := state.reload(); err != nil { + return nil, err + } + + return state.checkHostKey, nil +} + +func FingerprintHostKeyCallback(fingerprints ...string) (func(string, net.Addr, ssh.PublicKey) error, error) { + normalized := make([]string, 0, len(fingerprints)) + seen := make(map[string]struct{}, len(fingerprints)) + for _, raw := range fingerprints { + fingerprint, err := normalizeHostKeyFingerprint(raw) + if err != nil { + return nil, err + } + if fingerprint == "" { + continue + } + if _, exists := seen[fingerprint]; exists { + continue + } + seen[fingerprint] = struct{}{} + normalized = append(normalized, fingerprint) + } + if len(normalized) == 0 { + return nil, ErrHostFingerprintRequired + } + + return func(hostname string, remote net.Addr, key ssh.PublicKey) error { + actualSHA256 := ssh.FingerprintSHA256(key) + actualMD5 := normalizeMD5Fingerprint(ssh.FingerprintLegacyMD5(key)) + for _, want := range normalized { + if want == actualSHA256 || want == actualMD5 { + return nil + } + } + return &HostKeyFingerprintMismatchError{ + Expected: append([]string(nil), normalized...), + ActualSHA256: actualSHA256, + ActualLegacyMD5: actualMD5, + } + }, nil +} + +type acceptNewHostKeyState struct { + file string + hashHosts bool + includeRemoteAddress bool + fileMode os.FileMode + + mu sync.Mutex + cb ssh.HostKeyCallback +} + +func (s *acceptNewHostKeyState) checkHostKey(hostname string, remote net.Addr, key ssh.PublicKey) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.cb != nil { + err := s.cb(hostname, remote, key) + if err == nil { + return nil + } + + var keyErr *knownhosts.KeyError + if !errors.As(err, &keyErr) || len(keyErr.Want) != 0 { + return err + } + } + + line, err := buildAcceptNewKnownHostsLine(hostname, remote, key, s.hashHosts, s.includeRemoteAddress) + if err != nil { + return err + } + if err := appendKnownHostsLine(s.file, line, s.fileMode); err != nil { + return err + } + if err := s.reload(); err != nil { + return err + } + if s.cb == nil { + return errors.New("known_hosts callback is nil after reload") + } + return s.cb(hostname, remote, key) +} + +func (s *acceptNewHostKeyState) reload() error { + callback, err := loadKnownHostsCallback(s.file) + if err != nil { + return err + } + s.cb = callback + return nil +} + +func loadKnownHostsCallback(file string) (ssh.HostKeyCallback, error) { + _, err := os.Stat(file) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil, nil + } + return nil, err + } + return knownhosts.New(file) +} + +func normalizeAcceptNewHostKeyOptions(options AcceptNewHostKeyOptions) AcceptNewHostKeyOptions { + options.KnownHostsFile = strings.TrimSpace(options.KnownHostsFile) + if options.FileMode == 0 { + options.FileMode = 0o600 + } + return options +} + +func buildAcceptNewKnownHostsLine(hostname string, remote net.Addr, key ssh.PublicKey, hashHosts bool, includeRemoteAddress bool) (string, error) { + addresses := collectKnownHostsAddresses(hostname, remote, includeRemoteAddress) + if len(addresses) == 0 { + return "", errors.New("no hostname or remote address available for known_hosts entry") + } + + patterns := make([]string, 0, len(addresses)) + for _, address := range addresses { + normalized := knownhosts.Normalize(address) + if hashHosts { + patterns = append(patterns, knownhosts.HashHostname(normalized)) + continue + } + patterns = append(patterns, normalized) + } + + authorizedKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(key))) + return strings.Join(patterns, ",") + " " + authorizedKey, nil +} + +func collectKnownHostsAddresses(hostname string, remote net.Addr, includeRemoteAddress bool) []string { + addresses := make([]string, 0, 2) + seen := make(map[string]struct{}, 2) + + add := func(address string) { + address = strings.TrimSpace(address) + if address == "" { + return + } + normalized := knownhosts.Normalize(address) + if _, exists := seen[normalized]; exists { + return + } + seen[normalized] = struct{}{} + addresses = append(addresses, address) + } + + add(hostname) + if includeRemoteAddress && remote != nil { + add(remote.String()) + } + + return addresses +} + +func appendKnownHostsLine(file string, line string, mode os.FileMode) error { + if strings.TrimSpace(file) == "" { + return ErrKnownHostsFileRequired + } + if strings.TrimSpace(line) == "" { + return errors.New("known_hosts line is empty") + } + + dir := filepath.Dir(file) + if dir != "." && dir != "" { + if err := os.MkdirAll(dir, 0o700); err != nil { + return err + } + } + + handle, err := os.OpenFile(file, os.O_CREATE|os.O_APPEND|os.O_WRONLY, mode) + if err != nil { + return err + } + defer handle.Close() + + if _, err := handle.WriteString(line + "\n"); err != nil { + return err + } + return handle.Chmod(mode) +} + +func normalizeHostKeyFingerprint(raw string) (string, error) { + value := strings.TrimSpace(raw) + if value == "" { + return "", nil + } + + if strings.HasPrefix(strings.ToUpper(value), "SHA256:") { + suffix := strings.TrimSpace(value[len("SHA256:"):]) + if suffix == "" { + return "", ErrHostFingerprintRequired + } + return "SHA256:" + suffix, nil + } + if strings.HasPrefix(strings.ToUpper(value), "MD5:") { + suffix := strings.TrimSpace(value[len("MD5:"):]) + if suffix == "" { + return "", ErrHostFingerprintRequired + } + return normalizeMD5Fingerprint("MD5:" + suffix), nil + } + if strings.Count(value, ":") >= 2 { + return normalizeMD5Fingerprint("MD5:" + value), nil + } + return "SHA256:" + value, nil +} + +func normalizeMD5Fingerprint(value string) string { + if !strings.HasPrefix(strings.ToUpper(value), "MD5:") { + return "MD5:" + strings.ToLower(value) + } + return "MD5:" + strings.ToLower(value[len("MD5:"):]) +} diff --git a/keepalive.go b/keepalive.go new file mode 100644 index 0000000..d50d8d0 --- /dev/null +++ b/keepalive.go @@ -0,0 +1,135 @@ +package starssh + +import ( + "context" + "sync" + "time" +) + +var sendKeepAliveRequest = func(ctx context.Context, client sshClientRequester) error { + _, _, err := client.SendRequest("keepalive@openssh.com", true, nil) + return err +} + +func (s *StarSSH) Ping() error { + return s.PingContext(context.Background()) +} + +func (s *StarSSH) PingContext(ctx context.Context) error { + return s.pingContext(ctx, nil) +} + +func (s *StarSSH) PingContextCloseOnCancel(ctx context.Context) error { + return s.pingContext(ctx, s.Close) +} + +func (s *StarSSH) pingContext(ctx context.Context, onCancel func() error) error { + if ctx == nil { + ctx = context.Background() + } + + client, err := s.requireSSHClient() + if err != nil { + return err + } + + runCancel := func() {} + if onCancel != nil { + var cancelOnce sync.Once + runCancel = func() { + cancelOnce.Do(func() { + _ = onCancel() + }) + } + + cancelDone := make(chan struct{}) + defer close(cancelDone) + go func() { + select { + case <-ctx.Done(): + runCancel() + case <-cancelDone: + } + }() + } + + requestFunc := sendKeepAliveRequest + pingErr := make(chan error, 1) + go func() { + err := requestFunc(ctx, client) + select { + case pingErr <- err: + default: + } + }() + + select { + case err := <-pingErr: + return err + case <-ctx.Done(): + runCancel() + return ctx.Err() + } +} + +func (s *StarSSH) startAutoKeepAlive() { + if s == nil || s.snapshotSSHClient() == nil { + return + } + + interval := s.LoginInfo.KeepAliveInterval + if interval <= 0 { + return + } + timeout := s.LoginInfo.KeepAliveTimeout + if timeout <= 0 { + timeout = defaultKeepAliveTimeout + } + + stop := make(chan struct{}) + done := make(chan struct{}) + + s.keepaliveMu.Lock() + if s.keepaliveStop != nil { + s.keepaliveMu.Unlock() + return + } + s.keepaliveStop = stop + s.keepaliveDone = done + s.keepaliveMu.Unlock() + + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + defer close(done) + + for { + select { + case <-stop: + return + case <-ticker.C: + pingCtx, cancel := context.WithTimeout(context.Background(), timeout) + err := s.pingContext(pingCtx, nil) + cancel() + if err != nil { + s.closeFromKeepAlive() + return + } + } + } + }() +} + +func (s *StarSSH) stopAutoKeepAlive() { + stop, done := s.takeKeepaliveHandles() + if stop != nil { + close(stop) + } + if done != nil { + <-done + } +} + +func (s *StarSSH) closeFromKeepAlive() { + _ = s.closeTransport(false) +} diff --git a/keepalive_test.go b/keepalive_test.go new file mode 100644 index 0000000..a5f7b42 --- /dev/null +++ b/keepalive_test.go @@ -0,0 +1,53 @@ +package starssh + +import ( + "context" + "testing" + "time" + + "golang.org/x/crypto/ssh" +) + +func TestAutoKeepAliveTimeoutDoesNotDeadlock(t *testing.T) { + oldSendKeepAliveRequest := sendKeepAliveRequest + oldCloseSSHClient := closeSSHClient + t.Cleanup(func() { + sendKeepAliveRequest = oldSendKeepAliveRequest + closeSSHClient = oldCloseSSHClient + }) + + sendKeepAliveRequest = func(ctx context.Context, client sshClientRequester) error { + <-ctx.Done() + return ctx.Err() + } + closeSSHClient = func(client sshClientRequester) error { + return nil + } + + client := &ssh.Client{} + star := &StarSSH{ + LoginInfo: LoginInput{ + KeepAliveInterval: 10 * time.Millisecond, + KeepAliveTimeout: 20 * time.Millisecond, + }, + } + star.setTransport(client, nil) + star.startAutoKeepAlive() + + star.keepaliveMu.Lock() + done := star.keepaliveDone + star.keepaliveMu.Unlock() + if done == nil { + t.Fatal("keepalive did not start") + } + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("keepalive goroutine did not exit after keepalive timeout") + } + + if client := star.snapshotSSHClient(); client != nil { + t.Fatal("ssh client should be closed after keepalive timeout") + } +} diff --git a/login.go b/login.go new file mode 100644 index 0000000..948d62e --- /dev/null +++ b/login.go @@ -0,0 +1,362 @@ +package starssh + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "net" + "os" + "strings" + "time" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" +) + +var ErrHostKeyCallbackRequired = errors.New("host key callback is required; use DefaultAllowHostKeyCallback to explicitly allow any host key") +var errSSHAgentUnavailable = errors.New("ssh-agent unavailable") + +var defaultAuthOrder = []AuthMethodKind{ + AuthMethodSSHAgent, + AuthMethodPrivateKey, + AuthMethodPassword, + AuthMethodKeyboardInteractive, +} + +func DefaultAllowHostKeyCallback(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil +} + +func LoginContext(ctx context.Context, info LoginInput) (*StarSSH, error) { + return loginWithContext(ctx, info) +} + +func Login(info LoginInput) (*StarSSH, error) { + return LoginContext(context.Background(), info) +} + +func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) { + info = normalizeLoginInput(info) + if info.HostKeyCallback == nil { + return nil, ErrHostKeyCallbackRequired + } + + loginCtx, cancel := contextWithLoginTimeout(ctx, info.Timeout) + defer cancel() + + sshInfo := &StarSSH{ + LoginInfo: info, + } + + auth, authCleanup, err := buildAuthMethods(info) + if err != nil { + return nil, err + } + if authCleanup != nil { + defer authCleanup() + } + + hostKeyCallback := func(hostname string, remote net.Addr, key ssh.PublicKey) error { + sshInfo.PublicKey = key + sshInfo.RemoteAddr = remote + sshInfo.Hostname = hostname + + return info.HostKeyCallback(hostname, remote, key) + } + + bannerCallback := func(banner string) error { + sshInfo.Banner = banner + if info.BannerCallback != nil { + return info.BannerCallback(banner) + } + return nil + } + + clientConfig := &ssh.ClientConfig{ + User: info.User, + Auth: auth, + Timeout: info.Timeout, + HostKeyCallback: hostKeyCallback, + BannerCallback: bannerCallback, + } + if len(info.Ciphers) > 0 || len(info.MACs) > 0 || len(info.KeyExchanges) > 0 { + clientConfig.Config = ssh.Config{ + Ciphers: info.Ciphers, + MACs: info.MACs, + KeyExchanges: info.KeyExchanges, + } + } + + targetAddr := joinHostPort(info.Addr, info.Port) + rawConn, upstream, err := dialTargetConn(loginCtx, info) + if err != nil { + return sshInfo, err + } + restoreDeadline := applyConnDeadline(rawConn, loginCtx, info.Timeout) + defer restoreDeadline() + + clientConn, chans, reqs, err := ssh.NewClientConn(rawConn, targetAddr, clientConfig) + if err != nil { + _ = rawConn.Close() + if upstream != nil { + _ = upstream.Close() + } + return sshInfo, err + } + client := ssh.NewClient(clientConn, chans, reqs) + + sshInfo.setTransport(client, upstream) + if sshInfo.PublicKey != nil { + sshInfo.PubkeyBase64 = base64.StdEncoding.EncodeToString(sshInfo.PublicKey.Marshal()) + } + sshInfo.startAutoKeepAlive() + + return sshInfo, nil +} + +func contextWithLoginTimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { + if ctx == nil { + ctx = context.Background() + } + if timeout <= 0 { + return ctx, func() {} + } + return context.WithTimeout(ctx, timeout) +} + +func LoginSimple(host string, user string, passwd string, prikeyPath string, port int, timeout time.Duration) (*StarSSH, error) { + info := LoginInput{ + Addr: host, + Port: port, + Timeout: timeout, + User: user, + HostKeyCallback: DefaultAllowHostKeyCallback, + } + + if prikeyPath != "" { + prikey, err := os.ReadFile(prikeyPath) + if err != nil { + return nil, err + } + info.Prikey = string(prikey) + if passwd != "" { + info.PrikeyPwd = passwd + } + } else { + info.Password = passwd + } + + return Login(info) +} + +func normalizeLoginInput(info LoginInput) LoginInput { + if info.Port <= 0 { + info.Port = defaultSSHPort + } + if info.Timeout <= 0 { + info.Timeout = defaultLoginTimeout + } + return info +} + +func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) { + order, err := normalizeAuthOrder(info.AuthOrder) + if err != nil { + return nil, nil, err + } + + auth := make([]ssh.AuthMethod, 0, len(order)) + var agentErr error + var cleanupFuncs []func() + + for _, methodKind := range order { + switch methodKind { + case AuthMethodPrivateKey: + method, err := buildPrivateKeyAuthMethod(info) + if err != nil { + return nil, nil, err + } + if method != nil { + auth = append(auth, method) + } + case AuthMethodPassword: + method := buildPasswordAuthMethod(info.Password, info.PasswordCallback) + if method != nil { + auth = append(auth, method) + } + case AuthMethodKeyboardInteractive: + method := buildKeyboardInteractiveAuthMethod(info.Password, info.PasswordCallback, info.KeyboardInteractiveCallback) + if method != nil { + auth = append(auth, method) + } + case AuthMethodSSHAgent: + if info.DisableSSHAgent { + continue + } + agentMethod, cleanup, err := buildSSHAgentAuthMethod(info.Timeout) + if err != nil { + agentErr = err + continue + } + if agentMethod != nil { + auth = append(auth, agentMethod) + } + if cleanup != nil { + cleanupFuncs = append(cleanupFuncs, cleanup) + } + } + } + + if len(auth) == 0 { + if agentErr != nil { + return nil, nil, fmt.Errorf("no authentication method provided; ssh-agent unavailable: %w", agentErr) + } + return nil, nil, errors.New("no authentication method provided: password, private key, or ssh-agent is required") + } + + return auth, composeCleanup(cleanupFuncs...), nil +} + +func normalizeAuthOrder(order []AuthMethodKind) ([]AuthMethodKind, error) { + if len(order) == 0 { + return append([]AuthMethodKind(nil), defaultAuthOrder...), nil + } + + normalized := make([]AuthMethodKind, 0, len(order)) + seen := make(map[AuthMethodKind]struct{}, len(order)) + for _, raw := range order { + kind := AuthMethodKind(strings.ToLower(strings.TrimSpace(string(raw)))) + if kind == "" { + return nil, errors.New("auth order contains an empty auth method") + } + if !isSupportedAuthMethodKind(kind) { + return nil, fmt.Errorf("unsupported auth method %q", raw) + } + if _, exists := seen[kind]; exists { + continue + } + seen[kind] = struct{}{} + normalized = append(normalized, kind) + } + + if len(normalized) == 0 { + return nil, errors.New("auth order is empty") + } + return normalized, nil +} + +func isSupportedAuthMethodKind(kind AuthMethodKind) bool { + switch kind { + case AuthMethodPrivateKey, AuthMethodPassword, AuthMethodKeyboardInteractive, AuthMethodSSHAgent: + return true + default: + return false + } +} + +func buildPrivateKeyAuthMethod(info LoginInput) (ssh.AuthMethod, error) { + if strings.TrimSpace(info.Prikey) == "" { + return nil, nil + } + + pemBytes := []byte(info.Prikey) + if info.PrikeyPwd == "" { + signer, err := ssh.ParsePrivateKey(pemBytes) + if err != nil { + return nil, err + } + return ssh.PublicKeys(signer), nil + } + + signer, err := ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(info.PrikeyPwd)) + if err != nil { + return nil, err + } + return ssh.PublicKeys(signer), nil +} + +func buildPasswordAuthMethod(password string, callback func() (string, error)) ssh.AuthMethod { + if password != "" { + return ssh.Password(password) + } + if callback != nil { + return ssh.PasswordCallback(callback) + } + return nil +} + +func buildKeyboardInteractiveAuthMethod( + password string, + passwordCallback func() (string, error), + challenge ssh.KeyboardInteractiveChallenge, +) ssh.AuthMethod { + if challenge != nil { + return ssh.KeyboardInteractive(challenge) + } + if password == "" && passwordCallback == nil { + return nil + } + + keyboardInteractiveChallenge := func(user, instruction string, questions []string, echos []bool) ([]string, error) { + if len(questions) == 0 { + return []string{}, nil + } + + answer := password + if answer == "" { + var err error + answer, err = passwordCallback() + if err != nil { + return nil, err + } + } + + answers := make([]string, len(questions)) + for i := range questions { + answers[i] = answer + } + return answers, nil + } + return ssh.KeyboardInteractive(keyboardInteractiveChallenge) +} + +func buildSSHAgentAuthMethod(timeout time.Duration) (ssh.AuthMethod, func(), error) { + conn, err := dialSSHAgent(timeout) + if err != nil { + if errors.Is(err, errSSHAgentUnavailable) { + return nil, nil, nil + } + return nil, nil, err + } + if conn == nil { + return nil, nil, nil + } + + signers, err := agent.NewClient(conn).Signers() + if err != nil { + _ = conn.Close() + return nil, nil, err + } + if len(signers) == 0 { + _ = conn.Close() + return nil, nil, errors.New("ssh-agent has no loaded keys") + } + + return ssh.PublicKeys(signers...), func() { + _ = conn.Close() + }, nil +} + +func composeCleanup(funcs ...func()) func() { + if len(funcs) == 0 { + return nil + } + return func() { + for i := len(funcs) - 1; i >= 0; i-- { + if funcs[i] != nil { + funcs[i]() + } + } + } +} diff --git a/pool.go b/pool.go new file mode 100644 index 0000000..e0f91c8 --- /dev/null +++ b/pool.go @@ -0,0 +1,466 @@ +package starssh + +import ( + "context" + "errors" + "sync" + "time" +) + +const defaultExecPoolMaxOpenConns = 4 + +var ErrExecPoolClosed = errors.New("exec pool is closed") + +type ExecPoolConfig struct { + Login LoginInput + MaxOpenConns int + MaxIdleConns int + MaxIdleTime time.Duration + DisableHealthCheck bool + HealthCheckTimeout time.Duration +} + +type ExecPoolStats struct { + MaxOpenConns int + MaxIdleConns int + MaxIdleTime time.Duration + OpenConns int + IdleConns int + InUseConns int +} + +type ExecPool struct { + loginInfo LoginInput + maxOpen int + maxIdle int + maxIdleTime time.Duration + + idle chan *pooledClient + done chan struct{} + closeOnce sync.Once + healthCheckOnAcquire bool + healthCheckTimeout time.Duration + + mu sync.Mutex + open int + closed bool +} + +type pooledClient struct { + client *StarSSH + idleAt time.Time +} + +func NewExecPool(config ExecPoolConfig) *ExecPool { + maxOpen := config.MaxOpenConns + if maxOpen <= 0 { + maxOpen = defaultExecPoolMaxOpenConns + } + + maxIdle := config.MaxIdleConns + if maxIdle <= 0 || maxIdle > maxOpen { + maxIdle = maxOpen + } + + return &ExecPool{ + loginInfo: config.Login, + maxOpen: maxOpen, + maxIdle: maxIdle, + maxIdleTime: normalizeMaxIdleTime(config.MaxIdleTime), + idle: make(chan *pooledClient, maxIdle), + done: make(chan struct{}), + healthCheckOnAcquire: !config.DisableHealthCheck, + healthCheckTimeout: normalizeHealthCheckTimeout(config.HealthCheckTimeout), + } +} + +func (p *ExecPool) Exec(ctx context.Context, req ExecRequest) (*ExecResult, error) { + client, err := p.Acquire(ctx) + if err != nil { + return nil, err + } + + result, execErr := client.Exec(ctx, req) + if execErr != nil { + p.Discard(client) + return result, execErr + } + if releaseErr := p.Release(client); releaseErr != nil { + return result, releaseErr + } + return result, nil +} + +func (p *ExecPool) ExecString(ctx context.Context, command string) (*ExecResult, error) { + return p.Exec(ctx, ExecRequest{ + Command: command, + }) +} + +func (p *ExecPool) ExecStream(ctx context.Context, req ExecRequest, onChunk func(ExecStreamChunk)) (*ExecResult, error) { + client, err := p.Acquire(ctx) + if err != nil { + return nil, err + } + + result, execErr := client.ExecStream(ctx, req, onChunk) + if execErr != nil { + p.Discard(client) + return result, execErr + } + if releaseErr := p.Release(client); releaseErr != nil { + return result, releaseErr + } + return result, nil +} + +func (p *ExecPool) WarmUp(ctx context.Context, targetIdle int) error { + if p == nil { + return errors.New("exec pool is nil") + } + if ctx == nil { + ctx = context.Background() + } + + targetIdle = p.normalizeWarmUpTarget(targetIdle) + if targetIdle == 0 { + return nil + } + + for { + if err := ctx.Err(); err != nil { + return err + } + + idleCount, create, err := p.tryWarmUp(targetIdle) + if err != nil { + return err + } + if idleCount >= targetIdle || !create { + return nil + } + + conn, err := LoginContext(ctx, p.loginInfo) + if err != nil { + p.releaseSlot() + return err + } + if err := p.Release(conn); err != nil { + return err + } + } +} + +func (p *ExecPool) Acquire(ctx context.Context) (*StarSSH, error) { + if p == nil { + return nil, errors.New("exec pool is nil") + } + if ctx == nil { + ctx = context.Background() + } + + for { + idleClient, create, err := p.tryAcquire() + if err != nil { + return nil, err + } + if idleClient != nil { + client, ok := p.takeIdleClient(ctx, idleClient) + if !ok { + continue + } + return client, nil + } + if create { + conn, err := LoginContext(ctx, p.loginInfo) + if err != nil { + p.releaseSlot() + return nil, err + } + return conn, nil + } + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-p.done: + return nil, ErrExecPoolClosed + case idleClient = <-p.idle: + if idleClient == nil { + continue + } + client, ok := p.takeIdleClient(ctx, idleClient) + if !ok { + continue + } + return client, nil + } + } +} + +func (p *ExecPool) Release(client *StarSSH) error { + if p == nil { + return errors.New("exec pool is nil") + } + if client == nil { + p.releaseSlot() + return nil + } + + p.mu.Lock() + if p.closed { + p.mu.Unlock() + p.closeClient(client) + return nil + } + + select { + case p.idle <- &pooledClient{ + client: client, + idleAt: time.Now(), + }: + p.mu.Unlock() + return nil + default: + p.mu.Unlock() + p.closeClient(client) + return nil + } +} + +func (p *ExecPool) Discard(client *StarSSH) { + if p == nil { + return + } + if client == nil { + p.releaseSlot() + return + } + p.closeClient(client) +} + +func (p *ExecPool) Stats() ExecPoolStats { + if p == nil { + return ExecPoolStats{} + } + + p.mu.Lock() + defer p.mu.Unlock() + + idleCount := len(p.idle) + openCount := p.open + inUseCount := openCount - idleCount + if inUseCount < 0 { + inUseCount = 0 + } + + return ExecPoolStats{ + MaxOpenConns: p.maxOpen, + MaxIdleConns: p.maxIdle, + MaxIdleTime: p.maxIdleTime, + OpenConns: openCount, + IdleConns: idleCount, + InUseConns: inUseCount, + } +} + +func (p *ExecPool) Close() error { + if p == nil { + return nil + } + + var closeErr error + p.closeOnce.Do(func() { + p.mu.Lock() + p.closed = true + idleClients := p.drainIdleLocked() + p.mu.Unlock() + close(p.done) + + for _, client := range idleClients { + if err := client.Close(); err != nil && closeErr == nil { + closeErr = err + } + } + }) + return closeErr +} + +func (p *ExecPool) CloseIdle() error { + if p == nil { + return nil + } + + p.mu.Lock() + idleClients := p.drainIdleLocked() + p.mu.Unlock() + + var closeErr error + for _, client := range idleClients { + if err := client.Close(); err != nil && closeErr == nil { + closeErr = err + } + } + return closeErr +} + +func (p *ExecPool) tryAcquire() (*pooledClient, bool, error) { + p.mu.Lock() + defer p.mu.Unlock() + + if p.closed { + return nil, false, ErrExecPoolClosed + } + + select { + case client := <-p.idle: + return client, false, nil + default: + } + + if p.open < p.maxOpen { + p.open++ + return nil, true, nil + } + return nil, false, nil +} + +func (p *ExecPool) tryWarmUp(targetIdle int) (int, bool, error) { + p.mu.Lock() + defer p.mu.Unlock() + + if p.closed { + return 0, false, ErrExecPoolClosed + } + + idleCount := len(p.idle) + if idleCount >= targetIdle { + return idleCount, false, nil + } + if p.open >= p.maxOpen { + return idleCount, false, nil + } + + p.open++ + return idleCount, true, nil +} + +func (p *ExecPool) normalizeWarmUpTarget(targetIdle int) int { + if p == nil || p.maxIdle <= 0 { + return 0 + } + if targetIdle <= 0 { + return p.maxIdle + } + if targetIdle > p.maxIdle { + return p.maxIdle + } + return targetIdle +} + +func (p *ExecPool) takeIdleClient(ctx context.Context, idleClient *pooledClient) (*StarSSH, bool) { + if idleClient == nil { + return nil, false + } + if idleClient.client == nil { + p.releaseSlot() + return nil, false + } + if p.isIdleExpired(idleClient) { + p.closePooledClient(idleClient) + return nil, false + } + if err := p.healthCheckClient(ctx, idleClient.client); err != nil { + p.closePooledClient(idleClient) + return nil, false + } + return idleClient.client, true +} + +func (p *ExecPool) isIdleExpired(client *pooledClient) bool { + if p == nil || client == nil || client.client == nil { + return false + } + if p.maxIdleTime <= 0 || client.idleAt.IsZero() { + return false + } + return time.Since(client.idleAt) >= p.maxIdleTime +} + +func (p *ExecPool) drainIdleLocked() []*StarSSH { + clients := make([]*StarSSH, 0, len(p.idle)) + for { + select { + case idleClient := <-p.idle: + if p.open > 0 { + p.open-- + } + if idleClient == nil || idleClient.client == nil { + continue + } + clients = append(clients, idleClient.client) + default: + return clients + } + } +} + +func (p *ExecPool) releaseSlot() { + p.mu.Lock() + defer p.mu.Unlock() + if p.open > 0 { + p.open-- + } +} + +func (p *ExecPool) closeClient(client *StarSSH) { + if client != nil { + _ = client.Close() + } + p.releaseSlot() +} + +func (p *ExecPool) closePooledClient(client *pooledClient) { + if client == nil { + return + } + if client.client == nil { + p.releaseSlot() + return + } + p.closeClient(client.client) +} + +func (p *ExecPool) healthCheckClient(ctx context.Context, client *StarSSH) error { + if client == nil { + return errors.New("ssh client is nil") + } + if !p.healthCheckOnAcquire { + return nil + } + + if ctx == nil { + ctx = context.Background() + } + timeout := p.healthCheckTimeout + if timeout > 0 { + healthCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + ctx = healthCtx + } + return client.PingContext(ctx) +} + +func normalizeHealthCheckTimeout(timeout time.Duration) time.Duration { + if timeout <= 0 { + return defaultKeepAliveTimeout + } + return timeout +} + +func normalizeMaxIdleTime(timeout time.Duration) time.Duration { + if timeout <= 0 { + return 0 + } + return timeout +} diff --git a/session.go b/session.go new file mode 100644 index 0000000..769b6bf --- /dev/null +++ b/session.go @@ -0,0 +1,121 @@ +package starssh + +import ( + "errors" + "io" + "net" + "strings" + + "golang.org/x/crypto/ssh" +) + +func (s *StarSSH) Close() error { + return s.closeTransport(true) +} + +func (s *StarSSH) NewSession() (*ssh.Session, error) { + return s.NewPTYSession(nil) +} + +func (s *StarSSH) NewExecSession() (*ssh.Session, error) { + client, err := s.requireSSHClient() + if err != nil { + return nil, err + } + return NewExecSession(client) +} + +func (s *StarSSH) NewPTYSession(config *TerminalConfig) (*ssh.Session, error) { + client, err := s.requireSSHClient() + if err != nil { + return nil, err + } + return NewPTYSession(client, config) +} + +func NewTransferSession(client *ssh.Client) (*ssh.Session, error) { + return NewExecSession(client) +} + +func NewExecSession(client *ssh.Client) (*ssh.Session, error) { + if client == nil { + return nil, errors.New("ssh client is nil") + } + return client.NewSession() +} + +func NewSession(client *ssh.Client) (*ssh.Session, error) { + return NewPTYSession(client, nil) +} + +func NewPTYSession(client *ssh.Client, config *TerminalConfig) (*ssh.Session, error) { + if client == nil { + return nil, errors.New("ssh client is nil") + } + + session, err := client.NewSession() + if err != nil { + return nil, err + } + + cfg := normalizeTerminalConfig(config) + if err := session.RequestPty(cfg.Term, cfg.Rows, cfg.Columns, cfg.Modes); err != nil { + _ = session.Close() + return nil, err + } + + return session, nil +} + +func normalizeTerminalConfig(config *TerminalConfig) TerminalConfig { + cfg := TerminalConfig{ + Term: defaultPTYTerm, + Rows: defaultPTYRows, + Columns: defaultPTYColumns, + Modes: ssh.TerminalModes{ + ssh.ECHO: 1, + ssh.TTY_OP_ISPEED: 14400, + ssh.TTY_OP_OSPEED: 14400, + }, + } + if config == nil { + return cfg + } + + if strings.TrimSpace(config.Term) != "" { + cfg.Term = config.Term + } + if config.Rows > 0 { + cfg.Rows = config.Rows + } + if config.Columns > 0 { + cfg.Columns = config.Columns + } + if len(config.Modes) > 0 { + cfg.Modes = config.Modes + } + return cfg +} + +func normalizeAlreadyClosedError(err error) error { + if err == nil { + return nil + } + if errors.Is(err, net.ErrClosed) { + return nil + } + if errors.Is(err, io.EOF) { + return nil + } + if strings.Contains(err.Error(), "use of closed network connection") { + return nil + } + return err +} + +func closeUpstream(upstream *StarSSH) error { + if upstream == nil { + return nil + } + return upstream.Close() +} diff --git a/sftp.go b/sftp.go index a886d24..3bd40a3 100644 --- a/sftp.go +++ b/sftp.go @@ -2,191 +2,1564 @@ package starssh import ( "bytes" - "github.com/pkg/sftp" + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" "io" + "net" "os" + "path" + "path/filepath" + "strings" + "time" + + "github.com/pkg/sftp" ) -func (star *StarSSH) CreateSftpClient() (*sftp.Client, error) { - return sftp.NewClient(star.Client) +const ( + defaultSFTPRetryCount = 2 + defaultSFTPRetryInitialBackoff = 250 * time.Millisecond + defaultSFTPTempSuffix = ".starssh.tmp" +) + +const preservedFileModeBits os.FileMode = os.ModePerm | os.ModeSetuid | os.ModeSetgid | os.ModeSticky + +type SFTPTransferOptions struct { + BufferSize int + Progress func(float64) + RetryCount *int + RetryInitialBackoff *time.Duration + AtomicUpload *bool + AtomicDownload *bool + VerifySize *bool + VerifyChecksum *bool + TempSuffix string } -func (star *StarSSH) SftpTransferOut(localFilePath, remotePath string) error { - sftpC, err := star.CreateSftpClient() - if err != nil { - return err - } - defer sftpC.Close() - return SftpTransferOut(localFilePath, remotePath, sftpC) +type resolvedSFTPTransferOptions struct { + BufferSize int + Progress func(float64) + RetryCount int + RetryInitialBackoff time.Duration + AtomicUpload bool + AtomicDownload bool + VerifySize bool + VerifyChecksum bool + TempSuffix string } -func SftpTransferOut(localFilePath, remotePath string, sftpClient *sftp.Client) error { - srcFile, err := os.Open(localFilePath) - if err != nil { - return err - } - defer srcFile.Close() - // var remoteFileName = filepath.Base(localFilePath) - dstFile, err := sftpClient.Create(remotePath) - if err != nil { - return err - } - defer dstFile.Close() - for { - buf := make([]byte, 1048576) - n, err := srcFile.Read(buf) - dstFile.Write(buf[:n]) - if err == io.EOF { - break - } - } - return nil +type SFTPErrorCategory string + +const ( + SFTPErrorRetryable SFTPErrorCategory = "retryable" + SFTPErrorPermanent SFTPErrorCategory = "permanent" +) + +type SFTPTransferError struct { + Operation string + LocalPath string + RemotePath string + Attempt int + Category SFTPErrorCategory + Err error } -func (star *StarSSH) SftpTransferOutByte(localData []byte, remotePath string) error { - sftpC, err := star.CreateSftpClient() - if err != nil { - return err +func (e *SFTPTransferError) Error() string { + if e == nil { + return "" } - defer sftpC.Close() - return SftpTransferOutByte(localData, remotePath, sftpC) + return fmt.Sprintf("%s failed [%s] (attempt=%d, local=%q, remote=%q): %v", e.Operation, e.Category, e.Attempt, e.LocalPath, e.RemotePath, e.Err) } -func SftpTransferOutByte(localData []byte, remotePath string, sftpClient *sftp.Client) error { - dstFile, err := sftpClient.Create(remotePath) - if err != nil { - return err +func (e *SFTPTransferError) Unwrap() error { + if e == nil { + return nil } - defer dstFile.Close() - _, err = dstFile.Write(localData) - return err + return e.Err } -func (star *StarSSH) SftpTransferOutFunc(localFilePath, remotePath string, bufcap int, rtefunc func(float64)) error { - sftpC, err := star.CreateSftpClient() - if err != nil { - return err - } - defer sftpC.Close() - return SftpTransferOutFunc(localFilePath, remotePath, bufcap, rtefunc, sftpC) +type FS interface { + Stat(context.Context, string) (os.FileInfo, error) + ReadDir(context.Context, string) ([]os.FileInfo, error) + ReadFile(context.Context, string, *SFTPTransferOptions) ([]byte, error) + WriteFile(context.Context, string, []byte, *SFTPTransferOptions) error + MkdirAll(context.Context, string) error + Remove(context.Context, string) error + RemoveAll(context.Context, string) error + Rename(context.Context, string, string) error } -func SftpTransferOutFunc(localFilePath, remotePath string, bufcap int, rtefunc func(float64), sftpClient *sftp.Client) error { - num := 0 - srcFile, err := os.Open(localFilePath) - if err != nil { - return err +type SFTPFileSystem struct { + star *StarSSH +} + +type atomicReplaceTarget struct { + exists bool + mode os.FileMode +} + +var ( + sftpCopyWithProgressFunc = copyWithProgressContext + sftpVerifyRemoteSizeFunc = verifyRemoteSize + sftpVerifyLocalSizeFunc = verifyLocalSize + sftpLocalFileSHA256Func = localFileSHA256 + sftpRemoteFileSHA256Func = remoteFileSHA256 +) + +func DefaultSFTPTransferOptions() SFTPTransferOptions { + return SFTPTransferOptions{ + BufferSize: defaultTransferBufferSize, + RetryCount: SFTPInt(defaultSFTPRetryCount), + RetryInitialBackoff: SFTPDuration(defaultSFTPRetryInitialBackoff), + AtomicUpload: SFTPBool(true), + AtomicDownload: SFTPBool(true), + VerifySize: SFTPBool(true), + VerifyChecksum: SFTPBool(false), + TempSuffix: defaultSFTPTempSuffix, } - defer srcFile.Close() - stat, _ := os.Stat(localFilePath) - filebig := float64(stat.Size()) - //var remoteFileName = filepath.Base(localFilePath) - dstFile, err := sftpClient.Create(remotePath) - if err != nil { - return err +} + +func SFTPBool(value bool) *bool { + return &value +} + +func SFTPInt(value int) *int { + return &value +} + +func SFTPDuration(value time.Duration) *time.Duration { + return &value +} + +func (star *StarSSH) FS() *SFTPFileSystem { + return &SFTPFileSystem{star: star} +} + +func (star *StarSSH) Stat(remotePath string) (os.FileInfo, error) { + return star.StatContext(context.Background(), remotePath) +} + +func (star *StarSSH) StatContext(ctx context.Context, remotePath string) (os.FileInfo, error) { + return star.FS().Stat(ctx, remotePath) +} + +func (star *StarSSH) ReadDir(remotePath string) ([]os.FileInfo, error) { + return star.ReadDirContext(context.Background(), remotePath) +} + +func (star *StarSSH) ReadDirContext(ctx context.Context, remotePath string) ([]os.FileInfo, error) { + return star.FS().ReadDir(ctx, remotePath) +} + +func (star *StarSSH) ReadFile(remotePath string) ([]byte, error) { + return star.ReadFileContext(context.Background(), remotePath, nil) +} + +func (star *StarSSH) ReadFileContext(ctx context.Context, remotePath string, options *SFTPTransferOptions) ([]byte, error) { + return star.FS().ReadFile(ctx, remotePath, options) +} + +func (star *StarSSH) WriteFile(remotePath string, data []byte) error { + return star.WriteFileContext(context.Background(), remotePath, data, nil) +} + +func (star *StarSSH) WriteFileContext(ctx context.Context, remotePath string, data []byte, options *SFTPTransferOptions) error { + return star.FS().WriteFile(ctx, remotePath, data, options) +} + +func (star *StarSSH) MkdirAll(remotePath string) error { + return star.MkdirAllContext(context.Background(), remotePath) +} + +func (star *StarSSH) MkdirAllContext(ctx context.Context, remotePath string) error { + return star.FS().MkdirAll(ctx, remotePath) +} + +func (star *StarSSH) Remove(remotePath string) error { + return star.RemoveContext(context.Background(), remotePath) +} + +func (star *StarSSH) RemoveContext(ctx context.Context, remotePath string) error { + return star.FS().Remove(ctx, remotePath) +} + +func (star *StarSSH) RemoveAll(remotePath string) error { + return star.RemoveAllContext(context.Background(), remotePath) +} + +func (star *StarSSH) RemoveAllContext(ctx context.Context, remotePath string) error { + return star.FS().RemoveAll(ctx, remotePath) +} + +func (star *StarSSH) Rename(oldPath string, newPath string) error { + return star.RenameContext(context.Background(), oldPath, newPath) +} + +func (star *StarSSH) RenameContext(ctx context.Context, oldPath string, newPath string) error { + return star.FS().Rename(ctx, oldPath, newPath) +} + +func (fs *SFTPFileSystem) Stat(ctx context.Context, remotePath string) (os.FileInfo, error) { + if fs == nil || fs.star == nil { + return nil, errors.New("sftp filesystem is nil") } - defer dstFile.Close() - for { - buf := make([]byte, bufcap) - n, err := srcFile.Read(buf) - num += n - go rtefunc(float64(num) / filebig * 100) - dstFile.Write(buf[:n]) - if err == io.EOF { - break - } + if err := validateRemotePath(remotePath); err != nil { + return nil, err + } + + var info os.FileInfo + err := fs.star.runSFTPClientOperation(ctx, "sftp_stat", remotePath, func(client *sftp.Client) error { + out, err := client.Stat(remotePath) if err != nil { return err } - } - return nil -} - -func (star *StarSSH) SftpTransferInByte(remotePath string) ([]byte, error) { - sftpC, err := star.CreateSftpClient() + info = out + return nil + }) if err != nil { return nil, err } - defer sftpC.Close() - return SftpTransferInByte(remotePath, sftpC) + return info, nil } -func SftpTransferInByte(remotePath string, sftpClient *sftp.Client) ([]byte, error) { - dstFile, err := sftpClient.Open(remotePath) - if err != nil { - return []byte{}, err +func (fs *SFTPFileSystem) ReadDir(ctx context.Context, remotePath string) ([]os.FileInfo, error) { + if fs == nil || fs.star == nil { + return nil, errors.New("sftp filesystem is nil") } - defer dstFile.Close() - buf := new(bytes.Buffer) - _, err = dstFile.WriteTo(buf) - return buf.Bytes(), err + if err := validateRemotePath(remotePath); err != nil { + return nil, err + } + + var entries []os.FileInfo + err := fs.star.runSFTPClientOperation(ctx, "sftp_readdir", remotePath, func(client *sftp.Client) error { + out, err := client.ReadDir(remotePath) + if err != nil { + return err + } + entries = out + return nil + }) + if err != nil { + return nil, err + } + return entries, nil } -func (star *StarSSH) SftpTransferIn(src, dst string) error { - sftpC, err := star.CreateSftpClient() +func (fs *SFTPFileSystem) ReadFile(ctx context.Context, remotePath string, options *SFTPTransferOptions) ([]byte, error) { + if fs == nil || fs.star == nil { + return nil, errors.New("sftp filesystem is nil") + } + return fs.star.SftpTransferInByteContext(ctx, remotePath, options) +} + +func (fs *SFTPFileSystem) WriteFile(ctx context.Context, remotePath string, data []byte, options *SFTPTransferOptions) error { + if fs == nil || fs.star == nil { + return errors.New("sftp filesystem is nil") + } + return fs.star.SftpTransferOutByteContext(ctx, data, remotePath, options) +} + +func (fs *SFTPFileSystem) MkdirAll(ctx context.Context, remotePath string) error { + if fs == nil || fs.star == nil { + return errors.New("sftp filesystem is nil") + } + if err := validateRemotePath(remotePath); err != nil { + return err + } + + return fs.star.runSFTPClientOperation(ctx, "sftp_mkdir_all", remotePath, func(client *sftp.Client) error { + return client.MkdirAll(remotePath) + }) +} + +func (fs *SFTPFileSystem) Remove(ctx context.Context, remotePath string) error { + if fs == nil || fs.star == nil { + return errors.New("sftp filesystem is nil") + } + if err := validateRemotePath(remotePath); err != nil { + return err + } + + return fs.star.runSFTPClientOperationNoRetry(ctx, func(client *sftp.Client) error { + return removeRemotePath(client, remotePath) + }) +} + +func (fs *SFTPFileSystem) RemoveAll(ctx context.Context, remotePath string) error { + if fs == nil || fs.star == nil { + return errors.New("sftp filesystem is nil") + } + if err := validateRemotePath(remotePath); err != nil { + return err + } + + return fs.star.runSFTPClientOperation(ctx, "sftp_remove_all", remotePath, func(client *sftp.Client) error { + return removeRemoteAll(ctx, client, remotePath) + }) +} + +func (fs *SFTPFileSystem) Rename(ctx context.Context, oldPath string, newPath string) error { + if fs == nil || fs.star == nil { + return errors.New("sftp filesystem is nil") + } + if err := validateRemotePath(oldPath); err != nil { + return err + } + if err := validateRemotePath(newPath); err != nil { + return err + } + + return fs.star.runSFTPClientOperationNoRetry(ctx, func(client *sftp.Client) error { + return renameRemoteAtomic(client, oldPath, newPath) + }) +} + +func normalizeSFTPTransferOptions(options *SFTPTransferOptions) resolvedSFTPTransferOptions { + opts := DefaultSFTPTransferOptions() + if options == nil { + return resolvedSFTPTransferOptions{ + BufferSize: opts.BufferSize, + Progress: opts.Progress, + RetryCount: normalizeSFTPRetryCount(derefSFTPInt(opts.RetryCount, defaultSFTPRetryCount)), + RetryInitialBackoff: derefSFTPDuration(opts.RetryInitialBackoff, defaultSFTPRetryInitialBackoff), + AtomicUpload: derefSFTPBool(opts.AtomicUpload, true), + AtomicDownload: derefSFTPBool(opts.AtomicDownload, true), + VerifySize: derefSFTPBool(opts.VerifySize, true), + VerifyChecksum: derefSFTPBool(opts.VerifyChecksum, false), + TempSuffix: normalizeSFTPTempSuffix(opts.TempSuffix), + } + } + + if options.BufferSize > 0 { + opts.BufferSize = options.BufferSize + } + if options.Progress != nil { + opts.Progress = options.Progress + } + if options.RetryCount != nil { + opts.RetryCount = options.RetryCount + } + if options.RetryInitialBackoff != nil { + opts.RetryInitialBackoff = options.RetryInitialBackoff + } + if options.AtomicUpload != nil { + opts.AtomicUpload = options.AtomicUpload + } + if options.AtomicDownload != nil { + opts.AtomicDownload = options.AtomicDownload + } + if options.VerifySize != nil { + opts.VerifySize = options.VerifySize + } + if options.VerifyChecksum != nil { + opts.VerifyChecksum = options.VerifyChecksum + } + if strings.TrimSpace(options.TempSuffix) != "" { + opts.TempSuffix = options.TempSuffix + } + + return resolvedSFTPTransferOptions{ + BufferSize: opts.BufferSize, + Progress: opts.Progress, + RetryCount: normalizeSFTPRetryCount(derefSFTPInt(opts.RetryCount, defaultSFTPRetryCount)), + RetryInitialBackoff: derefSFTPDuration(opts.RetryInitialBackoff, defaultSFTPRetryInitialBackoff), + AtomicUpload: derefSFTPBool(opts.AtomicUpload, true), + AtomicDownload: derefSFTPBool(opts.AtomicDownload, true), + VerifySize: derefSFTPBool(opts.VerifySize, true), + VerifyChecksum: derefSFTPBool(opts.VerifyChecksum, false), + TempSuffix: normalizeSFTPTempSuffix(opts.TempSuffix), + } +} + +func derefSFTPBool(value *bool, fallback bool) bool { + if value == nil { + return fallback + } + return *value +} + +func derefSFTPInt(value *int, fallback int) int { + if value == nil { + return fallback + } + return *value +} + +func derefSFTPDuration(value *time.Duration, fallback time.Duration) time.Duration { + if value == nil { + return fallback + } + return *value +} + +func normalizeSFTPTempSuffix(value string) string { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return defaultSFTPTempSuffix + } + return trimmed +} + +func normalizeSFTPRetryCount(value int) int { + if value < 0 { + return 0 + } + return value +} + +func (star *StarSSH) runSFTPClientOperation(ctx context.Context, operation string, remotePath string, fn func(*sftp.Client) error) error { + if err := ensureContext(ctx); err != nil { + return err + } + opts := normalizeSFTPTransferOptions(nil) + return executeSFTPRetry(ctx, operation, "", remotePath, opts, func(attempt int) error { + return star.withIsolatedSFTPClient(ctx, fn) + }) +} + +func (star *StarSSH) runSFTPClientOperationNoRetry(ctx context.Context, fn func(*sftp.Client) error) error { + if err := ensureContext(ctx); err != nil { + return err + } + return star.withIsolatedSFTPClient(ctx, fn) +} + +func (star *StarSSH) CreateSftpClient() (*sftp.Client, error) { + client, err := star.requireSSHClient() + if err != nil { + return nil, err + } + return sftp.NewClient(client) +} + +func (star *StarSSH) withIsolatedSFTPClient(ctx context.Context, fn func(*sftp.Client) error) error { + if err := ensureContext(ctx); err != nil { + return err + } + + client, err := star.CreateSftpClient() if err != nil { return err } - defer sftpC.Close() - return SftpTransferIn(src, dst, sftpC) + defer client.Close() + + return fn(client) +} + +func (star *StarSSH) getReusableSFTPClient() (*sftp.Client, error) { + if star == nil { + return nil, errors.New("ssh client is nil") + } + + star.sftpMu.Lock() + defer star.sftpMu.Unlock() + + if star.sftpClient != nil { + return star.sftpClient, nil + } + + sshClient, err := star.requireSSHClient() + if err != nil { + return nil, err + } + + client, err := sftp.NewClient(sshClient) + if err != nil { + return nil, err + } + + star.sftpClient = client + return client, nil +} + +func (star *StarSSH) resetReusableSFTPClient() { + if star == nil { + return + } + + star.sftpMu.Lock() + defer star.sftpMu.Unlock() + + if star.sftpClient != nil { + _ = star.sftpClient.Close() + star.sftpClient = nil + } +} + +func (star *StarSSH) closeReusableSFTPClient() error { + if star == nil { + return nil + } + + star.sftpMu.Lock() + defer star.sftpMu.Unlock() + + if star.sftpClient == nil { + return nil + } + + err := star.sftpClient.Close() + star.sftpClient = nil + return err +} + +func (star *StarSSH) withReusableSFTPClient(ctx context.Context, fn func(*sftp.Client) error) error { + if err := ensureContext(ctx); err != nil { + return err + } + client, err := star.getReusableSFTPClient() + if err != nil { + return err + } + + return fn(client) +} + +func (star *StarSSH) runSFTPWithRetry( + ctx context.Context, + operation string, + localPath string, + remotePath string, + opts resolvedSFTPTransferOptions, + fn func(context.Context, *sftp.Client, resolvedSFTPTransferOptions) error, +) error { + return executeSFTPRetry(ctx, operation, localPath, remotePath, opts, func(attempt int) error { + return star.withIsolatedSFTPClient(ctx, func(client *sftp.Client) error { + return fn(ctx, client, opts) + }) + }) +} + +func (star *StarSSH) SftpTransferOut(localFilePath, remotePath string) error { + return star.SftpTransferOutContext(context.Background(), localFilePath, remotePath, nil) +} + +func (star *StarSSH) SftpTransferOutContext(ctx context.Context, localFilePath, remotePath string, options *SFTPTransferOptions) error { + opts := normalizeSFTPTransferOptions(options) + return star.runSFTPWithRetry(ctx, "sftp_put_file", localFilePath, remotePath, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error { + return transferOutContext(ctx, client, localFilePath, remotePath, opts) + }) +} + +func SftpTransferOut(localFilePath, remotePath string, sftpClient *sftp.Client) error { + return SftpTransferOutWithContext(context.Background(), localFilePath, remotePath, sftpClient, nil) +} + +func SftpTransferOutWithContext(ctx context.Context, localFilePath, remotePath string, sftpClient *sftp.Client, options *SFTPTransferOptions) error { + opts := normalizeSFTPTransferOptions(options) + return executeSFTPRetry(ctx, "sftp_put_file", localFilePath, remotePath, opts, func(attempt int) error { + return transferOutContext(ctx, sftpClient, localFilePath, remotePath, opts) + }) +} + +func (star *StarSSH) SftpTransferOutByte(localData []byte, remotePath string) error { + return star.SftpTransferOutByteContext(context.Background(), localData, remotePath, nil) +} + +func (star *StarSSH) SftpTransferOutByteContext(ctx context.Context, localData []byte, remotePath string, options *SFTPTransferOptions) error { + opts := normalizeSFTPTransferOptions(options) + return star.runSFTPWithRetry(ctx, "sftp_put_bytes", "", remotePath, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error { + return transferOutByteContext(ctx, client, localData, remotePath, opts) + }) +} + +func SftpTransferOutByte(localData []byte, remotePath string, sftpClient *sftp.Client) error { + return SftpTransferOutByteWithContext(context.Background(), localData, remotePath, sftpClient, nil) +} + +func SftpTransferOutByteWithContext(ctx context.Context, localData []byte, remotePath string, sftpClient *sftp.Client, options *SFTPTransferOptions) error { + opts := normalizeSFTPTransferOptions(options) + return executeSFTPRetry(ctx, "sftp_put_bytes", "", remotePath, opts, func(attempt int) error { + return transferOutByteContext(ctx, sftpClient, localData, remotePath, opts) + }) +} + +func (star *StarSSH) SftpTransferOutFunc(localFilePath, remotePath string, bufcap int, rtefunc func(float64)) error { + return star.SftpTransferOutContext(context.Background(), localFilePath, remotePath, &SFTPTransferOptions{ + BufferSize: bufcap, + Progress: rtefunc, + }) +} + +func SftpTransferOutFunc(localFilePath, remotePath string, bufcap int, rtefunc func(float64), sftpClient *sftp.Client) error { + return SftpTransferOutWithContext(context.Background(), localFilePath, remotePath, sftpClient, &SFTPTransferOptions{ + BufferSize: bufcap, + Progress: rtefunc, + }) +} + +func (star *StarSSH) SftpTransferInByte(remotePath string) ([]byte, error) { + return star.SftpTransferInByteContext(context.Background(), remotePath, nil) +} + +func (star *StarSSH) SftpTransferInByteContext(ctx context.Context, remotePath string, options *SFTPTransferOptions) ([]byte, error) { + opts := normalizeSFTPTransferOptions(options) + + var data []byte + err := star.runSFTPWithRetry(ctx, "sftp_get_bytes", "", remotePath, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error { + out, runErr := transferInByteContext(ctx, client, remotePath, opts) + if runErr != nil { + return runErr + } + data = out + return nil + }) + if err != nil { + return nil, err + } + return data, nil +} + +func SftpTransferInByte(remotePath string, sftpClient *sftp.Client) ([]byte, error) { + return SftpTransferInByteWithContext(context.Background(), remotePath, sftpClient, nil) +} + +func SftpTransferInByteWithContext(ctx context.Context, remotePath string, sftpClient *sftp.Client, options *SFTPTransferOptions) ([]byte, error) { + opts := normalizeSFTPTransferOptions(options) + + var data []byte + err := executeSFTPRetry(ctx, "sftp_get_bytes", "", remotePath, opts, func(attempt int) error { + out, runErr := transferInByteContext(ctx, sftpClient, remotePath, opts) + if runErr != nil { + return runErr + } + data = out + return nil + }) + if err != nil { + return nil, err + } + return data, nil +} + +func (star *StarSSH) SftpTransferIn(src, dst string) error { + return star.SftpTransferInContext(context.Background(), src, dst, nil) +} + +func (star *StarSSH) SftpTransferInContext(ctx context.Context, src, dst string, options *SFTPTransferOptions) error { + opts := normalizeSFTPTransferOptions(options) + return star.runSFTPWithRetry(ctx, "sftp_get_file", dst, src, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error { + return transferInContext(ctx, client, src, dst, opts) + }) } func SftpTransferIn(src, dst string, sftpClient *sftp.Client) error { - srcFile, err := sftpClient.Open(src) + return SftpTransferInWithContext(context.Background(), src, dst, sftpClient, nil) +} + +func SftpTransferInWithContext(ctx context.Context, src, dst string, sftpClient *sftp.Client, options *SFTPTransferOptions) error { + opts := normalizeSFTPTransferOptions(options) + return executeSFTPRetry(ctx, "sftp_get_file", dst, src, opts, func(attempt int) error { + return transferInContext(ctx, sftpClient, src, dst, opts) + }) +} + +func (star *StarSSH) SftpTransferInFunc(src, dst string, bufcap int, rtefunc func(float64)) error { + return star.SftpTransferInContext(context.Background(), src, dst, &SFTPTransferOptions{ + BufferSize: bufcap, + Progress: rtefunc, + }) +} + +func SftpTransferInFunc(src, dst string, bufcap int, rtefunc func(float64), sftpClient *sftp.Client) error { + return SftpTransferInWithContext(context.Background(), src, dst, sftpClient, &SFTPTransferOptions{ + BufferSize: bufcap, + Progress: rtefunc, + }) +} + +func transferOutContext(ctx context.Context, sftpClient *sftp.Client, localFilePath string, remotePath string, opts resolvedSFTPTransferOptions) error { + if err := ensureContext(ctx); err != nil { + return err + } + if err := validateSFTPClient(sftpClient); err != nil { + return err + } + if strings.TrimSpace(localFilePath) == "" || strings.TrimSpace(remotePath) == "" { + return errors.New("local path and remote path must not be empty") + } + + srcFile, err := os.Open(localFilePath) if err != nil { return err } defer srcFile.Close() - //var localFileName = filepath.Base(src) - dstFile, err := os.Create(dst) + + stat, err := srcFile.Stat() if err != nil { return err } - defer dstFile.Close() - if _, err = srcFile.WriteTo(dstFile); err != nil { + tempPath, targetPath := buildUploadTargetPath(remotePath, opts) + targetInfo := atomicReplaceTarget{} + if tempPath != "" { + out, err := inspectRemoteAtomicTarget(sftpClient, remotePath) + if err != nil { + return err + } + targetInfo = out + } + if tempPath != "" { + defer func() { + _ = sftpClient.Remove(tempPath) + }() + } + + dstFile, err := sftpClient.Create(targetPath) + if err != nil { return err } + if _, err := sftpCopyWithProgressFunc(ctx, dstFile, srcFile, opts.BufferSize, stat.Size(), opts.Progress); err != nil { + _ = dstFile.Close() + return err + } + if err := dstFile.Close(); err != nil { + return err + } + + verifyPath := remotePath + if tempPath != "" { + verifyPath = tempPath + } + + if opts.VerifySize { + if err := sftpVerifyRemoteSizeFunc(sftpClient, verifyPath, stat.Size()); err != nil { + return err + } + } + if opts.VerifyChecksum { + localHash, err := sftpLocalFileSHA256Func(ctx, localFilePath) + if err != nil { + return err + } + remoteHash, err := sftpRemoteFileSHA256Func(ctx, sftpClient, verifyPath) + if err != nil { + return err + } + if localHash != remoteHash { + return fmt.Errorf("checksum mismatch after upload: local=%s remote=%s", localHash, remoteHash) + } + } + + if tempPath != "" { + mode := stat.Mode() + if desiredMode, ok := determineAtomicReplaceMode(targetInfo, &mode); ok { + if err := applyRemoteFileMode(sftpClient, tempPath, desiredMode); err != nil { + return err + } + } + if err := renameRemoteAtomic(sftpClient, tempPath, remotePath); err != nil { + return err + } + tempPath = "" + } + return nil } -func (star *StarSSH) SftpTransferInFunc(src, dst string, bufcap int, rtefunc func(float64)) error { - sftpC, err := star.CreateSftpClient() +func transferOutByteContext(ctx context.Context, sftpClient *sftp.Client, localData []byte, remotePath string, opts resolvedSFTPTransferOptions) error { + if err := ensureContext(ctx); err != nil { + return err + } + if err := validateSFTPClient(sftpClient); err != nil { + return err + } + if strings.TrimSpace(remotePath) == "" { + return errors.New("remote path must not be empty") + } + + tempPath, targetPath := buildUploadTargetPath(remotePath, opts) + targetInfo := atomicReplaceTarget{} + if tempPath != "" { + out, err := inspectRemoteAtomicTarget(sftpClient, remotePath) + if err != nil { + return err + } + targetInfo = out + } + if tempPath != "" { + defer func() { + _ = sftpClient.Remove(tempPath) + }() + } + + dstFile, err := sftpClient.Create(targetPath) if err != nil { return err } - defer sftpC.Close() - return SftpTransferInFunc(src, dst, bufcap, rtefunc, sftpC) + + reader := bytes.NewReader(localData) + if _, err := sftpCopyWithProgressFunc(ctx, dstFile, reader, opts.BufferSize, int64(len(localData)), opts.Progress); err != nil { + _ = dstFile.Close() + return err + } + if err := dstFile.Close(); err != nil { + return err + } + + verifyPath := remotePath + if tempPath != "" { + verifyPath = tempPath + } + + if opts.VerifySize { + if err := sftpVerifyRemoteSizeFunc(sftpClient, verifyPath, int64(len(localData))); err != nil { + return err + } + } + if opts.VerifyChecksum { + localHash := checksumBytes(localData) + remoteHash, err := sftpRemoteFileSHA256Func(ctx, sftpClient, verifyPath) + if err != nil { + return err + } + if localHash != remoteHash { + return fmt.Errorf("checksum mismatch after upload: local=%s remote=%s", localHash, remoteHash) + } + } + + if tempPath != "" { + if desiredMode, ok := determineAtomicReplaceMode(targetInfo, nil); ok { + if err := applyRemoteFileMode(sftpClient, tempPath, desiredMode); err != nil { + return err + } + } + if err := renameRemoteAtomic(sftpClient, tempPath, remotePath); err != nil { + return err + } + tempPath = "" + } + + return nil } -func SftpTransferInFunc(src, dst string, bufcap int, rtefunc func(float64), sftpClient *sftp.Client) error { - num := 0 +func transferInContext(ctx context.Context, sftpClient *sftp.Client, src, dst string, opts resolvedSFTPTransferOptions) error { + if err := ensureContext(ctx); err != nil { + return err + } + if err := validateSFTPClient(sftpClient); err != nil { + return err + } + if strings.TrimSpace(src) == "" || strings.TrimSpace(dst) == "" { + return errors.New("source path and destination path must not be empty") + } + srcFile, err := sftpClient.Open(src) if err != nil { return err } defer srcFile.Close() - stat, _ := srcFile.Stat() - filebig := float64(stat.Size()) - //var localFileName = filepath.Base(src) - dstFile, err := os.Create(dst) + + stat, err := srcFile.Stat() if err != nil { return err } - defer dstFile.Close() + + targetInfo := atomicReplaceTarget{} + if opts.AtomicDownload { + out, err := inspectLocalAtomicTarget(dst) + if err != nil { + return err + } + targetInfo = out + } + + dstFile, tempPath, err := createLocalTransferFile(dst, opts) + if err != nil { + return err + } + if tempPath != "" { + defer func() { + _ = os.Remove(tempPath) + }() + } + + if _, err := sftpCopyWithProgressFunc(ctx, dstFile, srcFile, opts.BufferSize, stat.Size(), opts.Progress); err != nil { + _ = dstFile.Close() + return err + } + if err := dstFile.Close(); err != nil { + return err + } + + verifyPath := dst + if tempPath != "" { + verifyPath = tempPath + } + + if opts.VerifySize { + if err := sftpVerifyLocalSizeFunc(verifyPath, stat.Size()); err != nil { + return err + } + } + if opts.VerifyChecksum { + localHash, err := sftpLocalFileSHA256Func(ctx, verifyPath) + if err != nil { + return err + } + remoteHash, err := sftpRemoteFileSHA256Func(ctx, sftpClient, src) + if err != nil { + return err + } + if localHash != remoteHash { + return fmt.Errorf("checksum mismatch after download: local=%s remote=%s", localHash, remoteHash) + } + } + + if tempPath != "" { + mode := stat.Mode() + if desiredMode, ok := determineAtomicReplaceMode(targetInfo, &mode); ok { + if err := applyLocalFileMode(tempPath, desiredMode); err != nil { + return err + } + } + if err := renameLocalAtomic(tempPath, dst); err != nil { + return err + } + tempPath = "" + } + + return nil +} + +func transferInByteContext(ctx context.Context, sftpClient *sftp.Client, remotePath string, opts resolvedSFTPTransferOptions) ([]byte, error) { + if err := ensureContext(ctx); err != nil { + return nil, err + } + if err := validateSFTPClient(sftpClient); err != nil { + return nil, err + } + if strings.TrimSpace(remotePath) == "" { + return nil, errors.New("remote path must not be empty") + } + + srcFile, err := sftpClient.Open(remotePath) + if err != nil { + return nil, err + } + defer srcFile.Close() + + stat, err := srcFile.Stat() + if err != nil { + return nil, err + } + + var out bytes.Buffer + if _, err := sftpCopyWithProgressFunc(ctx, &out, srcFile, opts.BufferSize, stat.Size(), opts.Progress); err != nil { + return nil, err + } + + data := out.Bytes() + if opts.VerifySize && int64(len(data)) != stat.Size() { + return nil, fmt.Errorf("download size mismatch: local=%d remote=%d", len(data), stat.Size()) + } + if opts.VerifyChecksum { + localHash := checksumBytes(data) + remoteHash, err := sftpRemoteFileSHA256Func(ctx, sftpClient, remotePath) + if err != nil { + return nil, err + } + if localHash != remoteHash { + return nil, fmt.Errorf("checksum mismatch after download: local=%s remote=%s", localHash, remoteHash) + } + } + + return data, nil +} + +func executeSFTPRetry( + ctx context.Context, + operation string, + localPath string, + remotePath string, + opts resolvedSFTPTransferOptions, + fn func(attempt int) error, +) error { + backoff := opts.RetryInitialBackoff + if backoff <= 0 { + backoff = defaultSFTPRetryInitialBackoff + } + + for attempt := 0; attempt <= opts.RetryCount; attempt++ { + if err := ensureContext(ctx); err != nil { + return wrapSFTPTransferError(operation, localPath, remotePath, attempt, SFTPErrorPermanent, err) + } + + err := fn(attempt) + if err == nil { + return nil + } + + category := classifySFTPError(err) + wrappedErr := wrapSFTPTransferError(operation, localPath, remotePath, attempt, category, err) + if category != SFTPErrorRetryable || attempt >= opts.RetryCount { + return wrappedErr + } + + timer := time.NewTimer(backoff) + select { + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return wrapSFTPTransferError(operation, localPath, remotePath, attempt, SFTPErrorPermanent, ctx.Err()) + case <-timer.C: + } + if backoff < 4*time.Second { + backoff *= 2 + } + } + + return nil +} + +func wrapSFTPTransferError(operation, localPath, remotePath string, attempt int, category SFTPErrorCategory, err error) error { + if err == nil { + return nil + } + var transferErr *SFTPTransferError + if errors.As(err, &transferErr) { + return err + } + return &SFTPTransferError{ + Operation: operation, + LocalPath: localPath, + RemotePath: remotePath, + Attempt: attempt, + Category: category, + Err: err, + } +} + +func classifySFTPError(err error) SFTPErrorCategory { + if isRetryableTransferError(err) { + return SFTPErrorRetryable + } + return SFTPErrorPermanent +} + +func isRetryableTransferError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return false + } + if errors.Is(err, os.ErrNotExist) { + return false + } + + var netErr net.Error + if errors.As(err, &netErr) { + if netErr.Timeout() || netErr.Temporary() { + return true + } + } + + errText := strings.ToLower(err.Error()) + if strings.Contains(errText, "permission denied") || strings.Contains(errText, "no such file") { + return false + } + + retryableHints := []string{ + "connection reset", + "broken pipe", + "connection aborted", + "connection refused", + "connection lost", + "timeout", + "timed out", + "unexpected eof", + "use of closed network connection", + "transport is closing", + } + for _, hint := range retryableHints { + if strings.Contains(errText, hint) { + return true + } + } + + return false +} + +func validateSFTPClient(client *sftp.Client) error { + if client == nil { + return errors.New("sftp client is nil") + } + return nil +} + +func ensureContext(ctx context.Context) error { + if ctx == nil { + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + return nil + } +} + +func copyWithProgressContext(ctx context.Context, dst io.Writer, src io.Reader, bufSize int, total int64, progress func(float64)) (int64, error) { + buffer := make([]byte, normalizeBufferSize(bufSize)) + var copied int64 + + if progress != nil && total > 0 { + progress(0) + } + for { - buf := make([]byte, bufcap) - n, err := srcFile.Read(buf) - num += n - go rtefunc(float64(num) / filebig * 100) - dstFile.Write(buf[:n]) + if err := ensureContext(ctx); err != nil { + return copied, err + } + + n, readErr := src.Read(buffer) + if n > 0 { + if err := ensureContext(ctx); err != nil { + return copied, err + } + + written, writeErr := dst.Write(buffer[:n]) + copied += int64(written) + if writeErr != nil { + return copied, writeErr + } + if written != n { + return copied, io.ErrShortWrite + } + reportProgress(progress, copied, total) + } + + if readErr == io.EOF { + break + } + if readErr != nil { + return copied, readErr + } + } + + if progress != nil { + progress(100) + } + + return copied, nil +} + +func reportProgress(progress func(float64), copied int64, total int64) { + if progress == nil { + return + } + + if total <= 0 { + progress(100) + return + } + + percent := float64(copied) / float64(total) * 100 + if percent > 100 { + percent = 100 + } + progress(percent) +} + +func buildUploadTargetPath(remotePath string, opts resolvedSFTPTransferOptions) (tempPath string, targetPath string) { + targetPath = remotePath + if !opts.AtomicUpload { + return "", targetPath + } + + suffix := strings.TrimSpace(opts.TempSuffix) + if suffix == "" { + suffix = defaultSFTPTempSuffix + } + + tempPath = fmt.Sprintf("%s%s.%s", remotePath, suffix, newNonce(4)) + return tempPath, tempPath +} + +func createLocalTransferFile(localPath string, opts resolvedSFTPTransferOptions) (*os.File, string, error) { + if !opts.AtomicDownload { + file, err := os.Create(localPath) + if err != nil { + return nil, "", err + } + return file, "", nil + } + + dir := filepath.Dir(localPath) + pattern := fmt.Sprintf("%s%s.*", filepath.Base(localPath), normalizeSFTPTempSuffix(opts.TempSuffix)) + file, err := os.CreateTemp(dir, pattern) + if err != nil { + return nil, "", err + } + return file, file.Name(), nil +} + +func renameRemoteAtomic(client *sftp.Client, from, to string) error { + if from == to { + return nil + } + if _, err := inspectRemoteAtomicTarget(client, to); err != nil { + return err + } + + type posixRenamer interface { + PosixRename(string, string) error + } + + if renamer, ok := interface{}(client).(posixRenamer); ok { + if err := renamer.PosixRename(from, to); err == nil { + return nil + } + } + + renameErr := client.Rename(from, to) + if renameErr == nil { + return nil + } + + targetInfo, err := inspectRemoteAtomicTarget(client, to) + if err != nil { + return errors.Join(renameErr, err) + } + if !targetInfo.exists { + return renameErr + } + + backupPath := buildRenameBackupPath(to) + if err := client.Rename(to, backupPath); err != nil { + return errors.Join(renameErr, fmt.Errorf("backup existing target %q failed: %w", to, err)) + } + + if err := client.Rename(from, to); err != nil { + restoreErr := client.Rename(backupPath, to) + if restoreErr != nil { + return errors.Join(renameErr, err, fmt.Errorf("restore original target %q failed: %w", to, restoreErr)) + } + return errors.Join(renameErr, err) + } + + if err := removeRemotePath(client, backupPath); err != nil && !isNotExistError(err) { + return fmt.Errorf("rename succeeded but backup cleanup %q failed: %w", backupPath, err) + } + return nil +} + +func renameLocalAtomic(from, to string) error { + if from == to { + return nil + } + if _, err := inspectLocalAtomicTarget(to); err != nil { + return err + } + + renameErr := os.Rename(from, to) + if renameErr == nil { + return nil + } + + targetInfo, err := inspectLocalAtomicTarget(to) + if err != nil { + return errors.Join(renameErr, err) + } + if !targetInfo.exists { + return renameErr + } + + backupPath := buildLocalRenameBackupPath(to) + if err := os.Rename(to, backupPath); err != nil { + return errors.Join(renameErr, fmt.Errorf("backup existing local target %q failed: %w", to, err)) + } + + if err := os.Rename(from, to); err != nil { + restoreErr := os.Rename(backupPath, to) + if restoreErr != nil { + return errors.Join(renameErr, err, fmt.Errorf("restore original local target %q failed: %w", to, restoreErr)) + } + return errors.Join(renameErr, err) + } + + if err := os.Remove(backupPath); err != nil && !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("rename succeeded but local backup cleanup %q failed: %w", backupPath, err) + } + return nil +} + +func buildRenameBackupPath(targetPath string) string { + return fmt.Sprintf("%s%s.rename-backup.%s", targetPath, defaultSFTPTempSuffix, newNonce(4)) +} + +func buildLocalRenameBackupPath(targetPath string) string { + dir := filepath.Dir(targetPath) + name := fmt.Sprintf("%s%s.rename-backup.%s", filepath.Base(targetPath), defaultSFTPTempSuffix, newNonce(4)) + return filepath.Join(dir, name) +} + +func remotePathExists(client *sftp.Client, remotePath string) (bool, error) { + if client == nil { + return false, errors.New("sftp client is nil") + } + _, err := client.Lstat(remotePath) + if err == nil { + return true, nil + } + if isNotExistError(err) { + return false, nil + } + return false, err +} + +func localPathExists(localPath string) (bool, error) { + _, err := os.Lstat(localPath) + if err == nil { + return true, nil + } + if errors.Is(err, os.ErrNotExist) { + return false, nil + } + return false, err +} + +func verifyRemoteSize(client *sftp.Client, remotePath string, expected int64) error { + info, err := client.Stat(remotePath) + if err != nil { + return err + } + if info.Size() != expected { + return fmt.Errorf("remote size mismatch: got %d want %d", info.Size(), expected) + } + return nil +} + +func verifyLocalSize(localPath string, expected int64) error { + info, err := os.Stat(localPath) + if err != nil { + return err + } + if info.Size() != expected { + return fmt.Errorf("local size mismatch: got %d want %d", info.Size(), expected) + } + return nil +} + +func localFileSHA256(ctx context.Context, path string) (string, error) { + file, err := os.Open(path) + if err != nil { + return "", err + } + defer file.Close() + return readerSHA256(ctx, file) +} + +func remoteFileSHA256(ctx context.Context, client *sftp.Client, remotePath string) (string, error) { + file, err := client.Open(remotePath) + if err != nil { + return "", err + } + defer file.Close() + return readerSHA256(ctx, file) +} + +func checksumBytes(data []byte) string { + sum := sha256.Sum256(data) + return hex.EncodeToString(sum[:]) +} + +func readerSHA256(ctx context.Context, reader io.Reader) (string, error) { + hasher := sha256.New() + buf := make([]byte, normalizeBufferSize(defaultTransferBufferSize)) + + for { + if err := ensureContext(ctx); err != nil { + return "", err + } + + n, err := reader.Read(buf) + if n > 0 { + if _, writeErr := hasher.Write(buf[:n]); writeErr != nil { + return "", writeErr + } + } + if err == io.EOF { break } if err != nil { - return err + return "", err } } + + return hex.EncodeToString(hasher.Sum(nil)), nil +} + +func isNotExistError(err error) bool { + if err == nil { + return false + } + if os.IsNotExist(err) { + return true + } + return strings.Contains(strings.ToLower(err.Error()), "no such file") +} + +func validateRemotePath(remotePath string) error { + if strings.TrimSpace(remotePath) == "" { + return errors.New("remote path must not be empty") + } + return nil +} + +func inspectRemoteAtomicTarget(client *sftp.Client, remotePath string) (atomicReplaceTarget, error) { + if err := validateSFTPClient(client); err != nil { + return atomicReplaceTarget{}, err + } + + info, err := client.Lstat(remotePath) + if err != nil { + if isNotExistError(err) { + return atomicReplaceTarget{}, nil + } + return atomicReplaceTarget{}, err + } + if err := validateAtomicReplaceTarget(remotePath, info); err != nil { + return atomicReplaceTarget{}, err + } + return atomicReplaceTarget{ + exists: true, + mode: info.Mode(), + }, nil +} + +func inspectLocalAtomicTarget(localPath string) (atomicReplaceTarget, error) { + info, err := os.Lstat(localPath) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return atomicReplaceTarget{}, nil + } + return atomicReplaceTarget{}, err + } + if err := validateAtomicReplaceTarget(localPath, info); err != nil { + return atomicReplaceTarget{}, err + } + return atomicReplaceTarget{ + exists: true, + mode: info.Mode(), + }, nil +} + +func validateAtomicReplaceTarget(targetPath string, info os.FileInfo) error { + if info == nil { + return nil + } + mode := info.Mode() + switch { + case mode&os.ModeSymlink != 0: + return fmt.Errorf("atomic overwrite target %q is a symlink", targetPath) + case mode.IsRegular(): + return nil + default: + return fmt.Errorf("atomic overwrite target %q is %s", targetPath, describeFileInfoType(info)) + } +} + +func describeFileInfoType(info os.FileInfo) string { + if info == nil { + return "unknown" + } + mode := info.Mode() + switch { + case mode&os.ModeSymlink != 0: + return "a symlink" + case mode.IsDir(): + return "a directory" + case mode&os.ModeNamedPipe != 0: + return "a named pipe" + case mode&os.ModeSocket != 0: + return "a socket" + case mode&os.ModeDevice != 0 && mode&os.ModeCharDevice != 0: + return "a character device" + case mode&os.ModeDevice != 0: + return "a block device" + default: + return "not a regular file" + } +} + +func determineAtomicReplaceMode(target atomicReplaceTarget, sourceMode *os.FileMode) (os.FileMode, bool) { + if target.exists { + return normalizePreservedFileMode(target.mode), true + } + if sourceMode == nil { + return 0, false + } + return normalizePreservedFileMode(*sourceMode), true +} + +func normalizePreservedFileMode(mode os.FileMode) os.FileMode { + return mode & preservedFileModeBits +} + +func applyLocalFileMode(localPath string, mode os.FileMode) error { + return os.Chmod(localPath, normalizePreservedFileMode(mode)) +} + +func applyRemoteFileMode(client *sftp.Client, remotePath string, mode os.FileMode) error { + if err := validateSFTPClient(client); err != nil { + return err + } + return client.Chmod(remotePath, normalizePreservedFileMode(mode)) +} + +func removeRemotePath(client *sftp.Client, remotePath string) error { + info, err := client.Lstat(remotePath) + if err != nil { + return err + } + if info.Mode()&os.ModeSymlink != 0 { + return client.Remove(remotePath) + } + if info.IsDir() { + return client.RemoveDirectory(remotePath) + } + return client.Remove(remotePath) +} + +func removeRemoteAll(ctx context.Context, client *sftp.Client, remotePath string) error { + if err := ensureContext(ctx); err != nil { + return err + } + + info, err := client.Lstat(remotePath) + if err != nil { + if isNotExistError(err) { + return nil + } + return err + } + if info.Mode()&os.ModeSymlink != 0 || !info.IsDir() { + if err := client.Remove(remotePath); err != nil && !isNotExistError(err) { + return err + } + return nil + } + + entries, err := client.ReadDir(remotePath) + if err != nil { + return err + } + for _, entry := range entries { + childPath := path.Join(remotePath, entry.Name()) + if err := removeRemoteAll(ctx, client, childPath); err != nil { + return err + } + } + + if err := client.RemoveDirectory(remotePath); err != nil && !isNotExistError(err) { + return err + } return nil } diff --git a/sftp_test.go b/sftp_test.go new file mode 100644 index 0000000..1170c6f --- /dev/null +++ b/sftp_test.go @@ -0,0 +1,475 @@ +package starssh + +import ( + "context" + "errors" + "io" + "net" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/pkg/sftp" +) + +func TestNormalizeSFTPTransferOptionsDefaultsAtomicDownload(t *testing.T) { + opts := normalizeSFTPTransferOptions(nil) + if !opts.AtomicUpload { + t.Fatal("expected atomic upload to default to enabled") + } + if !opts.AtomicDownload { + t.Fatal("expected atomic download to default to enabled") + } +} + +func TestTransferOutContextVerifyFailurePreservesRemoteTarget(t *testing.T) { + client := newSFTPTestClient(t) + root := t.TempDir() + localPath := filepath.Join(root, "local.txt") + remotePath := filepath.Join(root, "remote.txt") + + if err := os.WriteFile(localPath, []byte("new payload"), 0o644); err != nil { + t.Fatalf("write local file: %v", err) + } + if err := os.WriteFile(remotePath, []byte("original remote"), 0o644); err != nil { + t.Fatalf("write remote file: %v", err) + } + + verifyErr := errors.New("verify failed") + var verifiedPath string + oldVerifyRemoteSize := sftpVerifyRemoteSizeFunc + sftpVerifyRemoteSizeFunc = func(client *sftp.Client, remotePath string, expected int64) error { + verifiedPath = remotePath + return verifyErr + } + t.Cleanup(func() { + sftpVerifyRemoteSizeFunc = oldVerifyRemoteSize + }) + + err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil)) + if !errors.Is(err, verifyErr) { + t.Fatalf("expected verify failure, got %v", err) + } + if verifiedPath == remotePath { + t.Fatal("expected upload verification to run against temp path before final rename") + } + + data, err := os.ReadFile(remotePath) + if err != nil { + t.Fatalf("read remote file: %v", err) + } + if string(data) != "original remote" { + t.Fatalf("remote target was replaced on verify failure: %q", string(data)) + } + assertNoTransferTemps(t, remotePath) +} + +func TestTransferOutContextRejectsRemoteSymlinkTarget(t *testing.T) { + client := newSFTPTestClient(t) + root := t.TempDir() + localPath := filepath.Join(root, "local.txt") + remoteRealPath := filepath.Join(root, "remote-real.txt") + remotePath := filepath.Join(root, "remote-link.txt") + + if err := os.WriteFile(localPath, []byte("new payload"), 0o644); err != nil { + t.Fatalf("write local file: %v", err) + } + if err := os.WriteFile(remoteRealPath, []byte("original remote"), 0o644); err != nil { + t.Fatalf("write remote backing file: %v", err) + } + if err := os.Symlink(remoteRealPath, remotePath); err != nil { + t.Skipf("symlink unsupported: %v", err) + } + + err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil)) + if err == nil || !strings.Contains(err.Error(), "symlink") { + t.Fatalf("expected symlink rejection, got %v", err) + } + + info, err := os.Lstat(remotePath) + if err != nil { + t.Fatalf("lstat remote symlink: %v", err) + } + if info.Mode()&os.ModeSymlink == 0 { + t.Fatal("expected remote target to remain a symlink") + } + + data, err := os.ReadFile(remoteRealPath) + if err != nil { + t.Fatalf("read remote backing file: %v", err) + } + if string(data) != "original remote" { + t.Fatalf("remote backing file changed unexpectedly: %q", string(data)) + } + assertNoTransferTemps(t, remotePath) +} + +func TestTransferOutContextRejectsRemoteDirectoryTarget(t *testing.T) { + client := newSFTPTestClient(t) + root := t.TempDir() + localPath := filepath.Join(root, "local.txt") + remotePath := filepath.Join(root, "remote-dir") + + if err := os.WriteFile(localPath, []byte("new payload"), 0o644); err != nil { + t.Fatalf("write local file: %v", err) + } + if err := os.Mkdir(remotePath, 0o755); err != nil { + t.Fatalf("mkdir remote target: %v", err) + } + + err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil)) + if err == nil || !strings.Contains(err.Error(), "directory") { + t.Fatalf("expected directory rejection, got %v", err) + } + + info, err := os.Stat(remotePath) + if err != nil { + t.Fatalf("stat remote directory: %v", err) + } + if !info.IsDir() { + t.Fatal("expected remote target to remain a directory") + } + assertNoTransferTemps(t, remotePath) +} + +func TestTransferOutContextPreservesRemoteModeOnOverwrite(t *testing.T) { + client := newSFTPTestClient(t) + root := t.TempDir() + localPath := filepath.Join(root, "local.txt") + remotePath := filepath.Join(root, "remote.txt") + + if err := os.WriteFile(localPath, []byte("new payload"), 0o644); err != nil { + t.Fatalf("write local file: %v", err) + } + if err := os.WriteFile(remotePath, []byte("original remote"), 0o755); err != nil { + t.Fatalf("write remote file: %v", err) + } + if err := os.Chmod(remotePath, 0o755); err != nil { + t.Fatalf("chmod remote file: %v", err) + } + + if err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil)); err != nil { + t.Fatalf("transfer out: %v", err) + } + + assertMode(t, remotePath, 0o755) + assertFileContent(t, remotePath, "new payload") +} + +func TestTransferOutContextAppliesLocalModeForNewRemoteFile(t *testing.T) { + client := newSFTPTestClient(t) + root := t.TempDir() + localPath := filepath.Join(root, "local.txt") + remotePath := filepath.Join(root, "remote.txt") + + if err := os.WriteFile(localPath, []byte("new payload"), 0o751); err != nil { + t.Fatalf("write local file: %v", err) + } + if err := os.Chmod(localPath, 0o751); err != nil { + t.Fatalf("chmod local file: %v", err) + } + + if err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil)); err != nil { + t.Fatalf("transfer out: %v", err) + } + + assertMode(t, remotePath, 0o751) + assertFileContent(t, remotePath, "new payload") +} + +func TestTransferOutByteContextPreservesRemoteModeOnOverwrite(t *testing.T) { + client := newSFTPTestClient(t) + root := t.TempDir() + remotePath := filepath.Join(root, "remote.txt") + + if err := os.WriteFile(remotePath, []byte("original remote"), 0o755); err != nil { + t.Fatalf("write remote file: %v", err) + } + if err := os.Chmod(remotePath, 0o755); err != nil { + t.Fatalf("chmod remote file: %v", err) + } + + if err := transferOutByteContext(context.Background(), client, []byte("byte payload"), remotePath, normalizeSFTPTransferOptions(nil)); err != nil { + t.Fatalf("transfer out bytes: %v", err) + } + + assertMode(t, remotePath, 0o755) + assertFileContent(t, remotePath, "byte payload") +} + +func TestTransferInContextVerifyFailurePreservesLocalTarget(t *testing.T) { + client := newSFTPTestClient(t) + root := t.TempDir() + srcPath := filepath.Join(root, "remote.txt") + dstPath := filepath.Join(root, "local.txt") + + if err := os.WriteFile(srcPath, []byte("fresh remote payload"), 0o644); err != nil { + t.Fatalf("write remote file: %v", err) + } + if err := os.WriteFile(dstPath, []byte("original local"), 0o644); err != nil { + t.Fatalf("write local file: %v", err) + } + + verifyErr := errors.New("verify local failed") + var verifiedPath string + oldVerifyLocalSize := sftpVerifyLocalSizeFunc + sftpVerifyLocalSizeFunc = func(localPath string, expected int64) error { + verifiedPath = localPath + return verifyErr + } + t.Cleanup(func() { + sftpVerifyLocalSizeFunc = oldVerifyLocalSize + }) + + err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)) + if !errors.Is(err, verifyErr) { + t.Fatalf("expected verify failure, got %v", err) + } + if verifiedPath == dstPath { + t.Fatal("expected download verification to run against temp path before final rename") + } + + data, err := os.ReadFile(dstPath) + if err != nil { + t.Fatalf("read local file: %v", err) + } + if string(data) != "original local" { + t.Fatalf("local target was replaced on verify failure: %q", string(data)) + } + assertNoTransferTemps(t, dstPath) +} + +func TestTransferInContextRejectsLocalSymlinkTarget(t *testing.T) { + client := newSFTPTestClient(t) + root := t.TempDir() + srcPath := filepath.Join(root, "remote.txt") + localRealPath := filepath.Join(root, "local-real.txt") + dstPath := filepath.Join(root, "local-link.txt") + + if err := os.WriteFile(srcPath, []byte("fresh remote payload"), 0o644); err != nil { + t.Fatalf("write remote file: %v", err) + } + if err := os.WriteFile(localRealPath, []byte("original local"), 0o644); err != nil { + t.Fatalf("write local backing file: %v", err) + } + if err := os.Symlink(localRealPath, dstPath); err != nil { + t.Skipf("symlink unsupported: %v", err) + } + + err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)) + if err == nil || !strings.Contains(err.Error(), "symlink") { + t.Fatalf("expected symlink rejection, got %v", err) + } + + info, err := os.Lstat(dstPath) + if err != nil { + t.Fatalf("lstat local symlink: %v", err) + } + if info.Mode()&os.ModeSymlink == 0 { + t.Fatal("expected local target to remain a symlink") + } + assertFileContent(t, localRealPath, "original local") + assertNoTransferTemps(t, dstPath) +} + +func TestTransferInContextRejectsLocalDirectoryTarget(t *testing.T) { + client := newSFTPTestClient(t) + root := t.TempDir() + srcPath := filepath.Join(root, "remote.txt") + dstPath := filepath.Join(root, "local-dir") + + if err := os.WriteFile(srcPath, []byte("fresh remote payload"), 0o644); err != nil { + t.Fatalf("write remote file: %v", err) + } + if err := os.Mkdir(dstPath, 0o755); err != nil { + t.Fatalf("mkdir local target: %v", err) + } + + err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)) + if err == nil || !strings.Contains(err.Error(), "directory") { + t.Fatalf("expected directory rejection, got %v", err) + } + + info, err := os.Stat(dstPath) + if err != nil { + t.Fatalf("stat local directory: %v", err) + } + if !info.IsDir() { + t.Fatal("expected local target to remain a directory") + } + assertNoTransferTemps(t, dstPath) +} + +func TestTransferInContextPreservesLocalModeOnOverwrite(t *testing.T) { + client := newSFTPTestClient(t) + root := t.TempDir() + srcPath := filepath.Join(root, "remote.txt") + dstPath := filepath.Join(root, "local.sh") + + if err := os.WriteFile(srcPath, []byte("#!/bin/sh\necho remote\n"), 0o644); err != nil { + t.Fatalf("write remote file: %v", err) + } + if err := os.WriteFile(dstPath, []byte("#!/bin/sh\necho local\n"), 0o755); err != nil { + t.Fatalf("write local file: %v", err) + } + if err := os.Chmod(dstPath, 0o755); err != nil { + t.Fatalf("chmod local file: %v", err) + } + + if err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)); err != nil { + t.Fatalf("transfer in: %v", err) + } + + assertMode(t, dstPath, 0o755) + assertFileContent(t, dstPath, "#!/bin/sh\necho remote\n") +} + +func TestTransferInContextAppliesRemoteModeForNewLocalFile(t *testing.T) { + client := newSFTPTestClient(t) + root := t.TempDir() + srcPath := filepath.Join(root, "remote.sh") + dstPath := filepath.Join(root, "local.sh") + + if err := os.WriteFile(srcPath, []byte("#!/bin/sh\necho remote\n"), 0o751); err != nil { + t.Fatalf("write remote file: %v", err) + } + if err := os.Chmod(srcPath, 0o751); err != nil { + t.Fatalf("chmod remote file: %v", err) + } + + if err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)); err != nil { + t.Fatalf("transfer in: %v", err) + } + + assertMode(t, dstPath, 0o751) + assertFileContent(t, dstPath, "#!/bin/sh\necho remote\n") +} + +func TestTransferInContextCopyFailurePreservesLocalTarget(t *testing.T) { + client := newSFTPTestClient(t) + root := t.TempDir() + srcPath := filepath.Join(root, "remote.txt") + dstPath := filepath.Join(root, "local.txt") + + if err := os.WriteFile(srcPath, []byte("fresh remote payload"), 0o644); err != nil { + t.Fatalf("write remote file: %v", err) + } + if err := os.WriteFile(dstPath, []byte("original local"), 0o644); err != nil { + t.Fatalf("write local file: %v", err) + } + + copyErr := errors.New("copy failed") + var copyTargetPath string + oldCopy := sftpCopyWithProgressFunc + sftpCopyWithProgressFunc = func(ctx context.Context, dst io.Writer, src io.Reader, bufSize int, total int64, progress func(float64)) (int64, error) { + file, ok := dst.(*os.File) + if !ok { + t.Fatalf("expected local temp file writer, got %T", dst) + } + copyTargetPath = file.Name() + + buf := make([]byte, 8) + n, readErr := src.Read(buf) + if readErr != nil && !errors.Is(readErr, io.EOF) { + return 0, readErr + } + if n > 0 { + written, err := dst.Write(buf[:n]) + if err != nil { + return int64(written), err + } + return int64(written), copyErr + } + return 0, copyErr + } + t.Cleanup(func() { + sftpCopyWithProgressFunc = oldCopy + }) + + err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)) + if !errors.Is(err, copyErr) { + t.Fatalf("expected copy failure, got %v", err) + } + if copyTargetPath == dstPath { + t.Fatal("expected partial download writes to stay on temp path") + } + + data, err := os.ReadFile(dstPath) + if err != nil { + t.Fatalf("read local file: %v", err) + } + if string(data) != "original local" { + t.Fatalf("local target was modified by partial download: %q", string(data)) + } + assertNoTransferTemps(t, dstPath) +} + +func newSFTPTestClient(t *testing.T) *sftp.Client { + t.Helper() + + serverConn, clientConn := net.Pipe() + server, err := sftp.NewServer(serverConn) + if err != nil { + t.Fatalf("create sftp server: %v", err) + } + + serveErrCh := make(chan error, 1) + go func() { + serveErrCh <- server.Serve() + }() + + client, err := sftp.NewClientPipe(clientConn, clientConn) + if err != nil { + _ = server.Close() + t.Fatalf("create sftp client: %v", err) + } + + t.Cleanup(func() { + _ = client.Close() + _ = server.Close() + serveErr := <-serveErrCh + if serveErr == nil || errors.Is(serveErr, io.EOF) || normalizeAlreadyClosedError(serveErr) == nil { + return + } + t.Errorf("unexpected sftp server error: %v", serveErr) + }) + + return client +} + +func assertNoTransferTemps(t *testing.T, targetPath string) { + t.Helper() + + matches, err := filepath.Glob(targetPath + defaultSFTPTempSuffix + "*") + if err != nil { + t.Fatalf("glob temp files: %v", err) + } + if len(matches) != 0 { + t.Fatalf("expected temp artifacts to be cleaned up, got %v", matches) + } +} + +func assertMode(t *testing.T, targetPath string, want os.FileMode) { + t.Helper() + + info, err := os.Stat(targetPath) + if err != nil { + t.Fatalf("stat %q: %v", targetPath, err) + } + if got := info.Mode().Perm(); got != want { + t.Fatalf("unexpected mode for %q: got %o want %o", targetPath, got, want) + } +} + +func assertFileContent(t *testing.T, targetPath string, want string) { + t.Helper() + + data, err := os.ReadFile(targetPath) + if err != nil { + t.Fatalf("read %q: %v", targetPath, err) + } + if string(data) != want { + t.Fatalf("unexpected content for %q: got %q want %q", targetPath, string(data), want) + } +} diff --git a/shell.go b/shell.go new file mode 100644 index 0000000..f30cdc4 --- /dev/null +++ b/shell.go @@ -0,0 +1,592 @@ +package starssh + +import ( + "bufio" + "context" + "errors" + "fmt" + "io" + "strings" + "time" +) + +var errStarShellPOSIXOnly = errors.New("legacy StarShell only supports POSIX-compatible shells") + +type ShellRequest struct { + Command string + Timeout time.Duration + Keyword string + UseWaitDefault *bool +} + +// NewShell creates the legacy prompt-driven POSIX shell helper. +// For raw interactive terminal flows, prefer NewTerminal. +func (s *StarSSH) NewShell() (shell *StarShell, err error) { + shell = &StarShell{ + UseWaitDefault: true, + WaitTimeout: defaultShellWaitTimeout, + isecho: true, + iscolor: true, + promptToken: defaultShellPromptToken, + } + + shell.Session, err = s.NewPTYSession(nil) + if err != nil { + return nil, err + } + + shell.in, err = shell.Session.StdinPipe() + if err != nil { + _ = shell.Session.Close() + return nil, err + } + + stdout, err := shell.Session.StdoutPipe() + if err != nil { + _ = shell.Session.Close() + return nil, err + } + shell.out = bufio.NewReader(stdout) + + stderr, err := shell.Session.StderrPipe() + if err != nil { + _ = shell.Session.Close() + return nil, err + } + shell.er = bufio.NewReader(stderr) + + if err := shell.Session.Shell(); err != nil { + _ = shell.Session.Close() + return nil, err + } + + go shell.watchSession() + shell.gohub() + + if err := shell.configurePrompt(context.Background()); err != nil { + _ = shell.Session.Close() + return nil, err + } + + shell.Clear() + return shell, nil +} + +func (s *StarShell) configurePrompt(ctx context.Context) error { + if s == nil { + return errors.New("shell is nil") + } + if ctx == nil { + ctx = context.Background() + } + if _, hasDeadline := ctx.Deadline(); !hasDeadline { + timeoutCtx, cancel := context.WithTimeout(ctx, defaultShellSetupTimeout) + defer cancel() + ctx = timeoutCtx + } + + prompt := s.promptToken + " " + setupCommands := []string{ + "unset PROMPT_COMMAND >/dev/null 2>&1 || true", + fmt.Sprintf("export PS1=%s PS2='' PROMPT=%s RPROMPT='' >/dev/null 2>&1 || true", shellSingleQuote(prompt), shellSingleQuote(prompt)), + } + + s.Clear() + for _, cmd := range setupCommands { + if err := s.WriteCommand(cmd); err != nil { + return fmt.Errorf("%w: %v", errStarShellPOSIXOnly, err) + } + } + + probeToken := "__STARSSH_POSIX_READY__" + newNonce(6) + if err := s.WriteCommand(fmt.Sprintf("printf '%%s\\n' %s", shellSingleQuote(probeToken))); err != nil { + return fmt.Errorf("%w: %v", errStarShellPOSIXOnly, err) + } + + ticker := time.NewTicker(defaultShellPollInterval) + defer ticker.Stop() + + for { + outRaw, errRaw, runErr := s.readState() + if runErr != nil { + return fmt.Errorf("%w: %v", errStarShellPOSIXOnly, runErr) + } + + outs := normalizeShellOutput(stripControlSequences(string(outRaw))) + if strings.Contains(outs, probeToken) { + s.Clear() + return nil + } + + errs := normalizeShellOutput(stripControlSequences(string(errRaw))) + if looksLikeNonPOSIXShellError(errs) { + return fmt.Errorf("%w: %s", errStarShellPOSIXOnly, errs) + } + + select { + case <-ctx.Done(): + if errors.Is(ctx.Err(), context.DeadlineExceeded) { + return fmt.Errorf("%w: prompt bootstrap timed out", errStarShellPOSIXOnly) + } + return ctx.Err() + case <-ticker.C: + } + } +} + +func (s *StarShell) watchSession() { + if s == nil || s.Session == nil { + return + } + if err := s.Session.Wait(); err != nil && !errors.Is(err, io.EOF) { + s.setError(err) + } +} + +func (s *StarShell) Close() error { + if s == nil { + return nil + } + + var closeErr error + s.closeOnce.Do(func() { + if s.Session == nil { + return + } + closeErr = s.Session.Close() + }) + return closeErr +} + +func (s *StarShell) SwitchNoColor(run bool) { + s.rw.Lock() + defer s.rw.Unlock() + s.iscolor = run +} + +func (s *StarShell) SwitchEcho(run bool) { + s.rw.Lock() + defer s.rw.Unlock() + s.isecho = run +} + +func (s *StarShell) TrimColor(str string) string { + s.rw.RLock() + shouldTrim := s.iscolor + s.rw.RUnlock() + + if shouldTrim { + return SedColor(str) + } + return str +} + +/* +本函数控制是否在本地屏幕上打印远程Shell的输出内容[true|false] +*/ +func (s *StarShell) SwitchPrint(run bool) { + s.rw.Lock() + defer s.rw.Unlock() + s.isprint = run +} + +/* +本函数控制是否立即处理远程Shell输出每一行内容[true|false] +*/ +func (s *StarShell) SwitchFunc(run bool) { + s.rw.Lock() + defer s.rw.Unlock() + s.isfuncs = run +} + +func (s *StarShell) SetFunc(funcs func(string)) { + s.rw.Lock() + defer s.rw.Unlock() + s.funcs = funcs +} + +func (s *StarShell) Clear() { + s.rw.Lock() + defer s.rw.Unlock() + s.outbyte = []byte{} + s.errbyte = []byte{} +} + +func (s *StarShell) ShellClear(cmd string, sleep int) (string, string, error) { + s.Clear() + defer s.Clear() + return s.Shell(cmd, sleep) +} + +func (s *StarShell) Shell(cmd string, sleep int) (string, string, error) { + s.commandMu.Lock() + defer s.commandMu.Unlock() + + if err := s.WriteCommand(cmd); err != nil { + return "", "", err + } + + outRaw, errRaw, runErr := s.GetResult(sleep) + if runErr != nil { + return "", "", runErr + } + + outText := s.TrimColor(strings.TrimSpace(string(outRaw))) + + s.rw.RLock() + echoEnabled := s.isecho + s.rw.RUnlock() + if echoEnabled { + outText = stripCommandEchoFromOutput(outText, cmd) + } + + return strings.TrimSpace(outText), s.TrimColor(strings.TrimSpace(string(errRaw))), nil +} + +func (s *StarShell) ShellWait(cmd string) (string, string, error) { + result, err := s.Run(context.Background(), ShellRequest{ + Command: cmd, + }) + if err != nil { + return "", "", err + } + return strings.TrimSpace(result.StdoutString()), strings.TrimSpace(result.StderrString()), result.CommandError() +} + +func (s *StarShell) RunString(ctx context.Context, command string) (*ExecResult, error) { + return s.Run(ctx, ShellRequest{ + Command: command, + }) +} + +func (s *StarShell) Run(ctx context.Context, req ShellRequest) (*ExecResult, error) { + if s == nil { + return nil, errors.New("shell is nil") + } + if ctx == nil { + ctx = context.Background() + } + + s.commandMu.Lock() + defer s.commandMu.Unlock() + + if strings.TrimSpace(req.Command) == "" { + return nil, errors.New("command is empty") + } + + s.rw.RLock() + useDefault := s.UseWaitDefault + keyword := s.Keyword + waitTimeout := s.WaitTimeout + promptToken := s.promptToken + s.rw.RUnlock() + if req.UseWaitDefault != nil { + useDefault = *req.UseWaitDefault + } + if strings.TrimSpace(req.Keyword) != "" { + keyword = req.Keyword + } + if req.Timeout > 0 { + waitTimeout = req.Timeout + } + if !useDefault && keyword == "" { + return nil, errors.New("ShellRun requires UseWaitDefault=true or Keyword set") + } + if waitTimeout <= 0 { + waitTimeout = defaultShellWaitTimeout + } + if _, hasDeadline := ctx.Deadline(); !hasDeadline { + timeoutCtx, cancel := context.WithTimeout(ctx, waitTimeout) + defer cancel() + ctx = timeoutCtx + } + startAt := time.Now() + + s.Clear() + defer s.Clear() + + beginToken, endToken := newCommandTokens() + markerCmd := fmt.Sprintf("__STARSSH_RC=$?; printf '%s:%%s\\n' \"$__STARSSH_RC\"", endToken) + + if err := s.WriteCommand(fmt.Sprintf("printf '%s\\n'", beginToken)); err != nil { + return nil, err + } + if err := s.WriteCommand(req.Command); err != nil { + return nil, err + } + if err := s.WriteCommand(markerCmd); err != nil { + return nil, err + } + + var ( + outc string + errc string + exitCode int + done bool + ) + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(defaultShellPollInterval): + } + + outRaw, errRaw, runErr := s.readState() + if runErr != nil { + return nil, runErr + } + + outs := normalizeShellOutput(stripControlSequences(string(outRaw))) + errs := normalizeShellOutput(stripControlSequences(string(errRaw))) + + s.rw.RLock() + useDefault = s.UseWaitDefault + keyword = s.Keyword + s.rw.RUnlock() + + if useDefault { + segment, rc, found, parseErr := extractCommandSegment(outs, beginToken, endToken) + if parseErr != nil { + return nil, parseErr + } + if found { + outc = segment + errc = errs + exitCode = rc + done = true + break + } + } + + if keyword != "" { + if strings.Contains(outs, keyword) || strings.Contains(errs, keyword) { + outc = outs + errc = errs + done = true + break + } + } + } + if !done { + return nil, errors.New("failed to collect shell result") + } + + outc = collectLinesForCommandOutput(outc, promptToken, beginToken, endToken) + errc = collectLinesForCommandOutput(errc, promptToken, beginToken, endToken) + outc = stripCommandEchoFromOutput(outc, req.Command) + + stdoutText := strings.TrimSpace(outc) + stderrText := strings.TrimSpace(errc) + result := &ExecResult{ + Command: req.Command, + Stdout: []byte(stdoutText), + Stderr: []byte(stderrText), + Combined: combineCommandOutput(stdoutText, stderrText), + Duration: time.Since(startAt), + } + if useDefault { + result.ExitCode = exitCode + } + return result, nil +} + +func extractCommandSegment(stdout string, beginToken string, endToken string) (string, int, bool, error) { + lines := strings.Split(stdout, "\n") + beginLine := -1 + for i, line := range lines { + if strings.TrimSpace(line) == beginToken { + beginLine = i + break + } + } + if beginLine < 0 { + return "", 0, false, nil + } + + segment := strings.Join(lines[beginLine+1:], "\n") + before, rc, found, err := splitByEndToken(segment, endToken) + if err != nil { + return "", 0, false, err + } + if !found { + return "", 0, false, nil + } + return strings.TrimSpace(before), rc, true, nil +} + +func (s *StarShell) GetResult(sleep int) ([]byte, []byte, error) { + if sleep > 0 { + time.Sleep(time.Millisecond * time.Duration(sleep)) + } + return s.readState() +} + +func (s *StarShell) WriteCommand(cmd string) error { + return s.Write([]byte(cmd + "\n")) +} + +func (s *StarShell) Write(bstr []byte) error { + if s == nil { + return errors.New("shell is nil") + } + if s.in == nil { + return errors.New("shell stdin is not initialized") + } + if _, _, runErr := s.readState(); runErr != nil { + return runErr + } + + s.writeMu.Lock() + defer s.writeMu.Unlock() + + _, err := s.in.Write(bstr) + return err +} + +func (s *StarShell) gohub() { + if s.er != nil { + go s.streamPump(s.er, true) + } + if s.out != nil { + go s.streamPump(s.out, false) + } +} + +func (s *StarShell) streamPump(reader *bufio.Reader, isStderr bool) { + var cache []byte + + for { + read, err := reader.ReadByte() + if err == io.EOF { + return + } + if err != nil { + s.setError(err) + return + } + + s.rw.Lock() + if isStderr { + s.errbyte = append(s.errbyte, read) + } else { + s.outbyte = append(s.outbyte, read) + } + printEnabled := s.isprint + funcEnabled := s.isfuncs && s.funcs != nil + lineHandler := s.funcs + trimColor := s.iscolor + s.rw.Unlock() + + if printEnabled { + fmt.Print(string([]byte{read})) + } + + cache = append(cache, read) + if read == '\n' { + if funcEnabled { + line := strings.TrimSpace(string(cache)) + if trimColor { + line = SedColor(line) + } + go lineHandler(line) + } + cache = cache[:0] + } + } +} + +func (s *StarShell) setError(err error) { + if err == nil || errors.Is(err, io.EOF) { + return + } + + s.rw.Lock() + defer s.rw.Unlock() + if s.errors == nil { + s.errors = err + } +} + +func (s *StarShell) readState() ([]byte, []byte, error) { + s.rw.RLock() + defer s.rw.RUnlock() + + outCopy := make([]byte, len(s.outbyte)) + copy(outCopy, s.outbyte) + + errCopy := make([]byte, len(s.errbyte)) + copy(errCopy, s.errbyte) + + if s.errors != nil { + return outCopy, errCopy, s.errors + } + + return outCopy, errCopy, nil +} + +func stripCommandEchoFromOutput(output string, cmd string) string { + if output == "" || cmd == "" { + return strings.TrimSpace(output) + } + + lines := strings.Split(output, "\n") + cmdLines := strings.Split(cmd, "\n") + + for _, cmdLine := range cmdLines { + trimmedCmd := strings.TrimSpace(cmdLine) + if trimmedCmd == "" { + continue + } + + for i, line := range lines { + if strings.TrimSpace(line) == trimmedCmd { + lines = append(lines[:i], lines[i+1:]...) + break + } + } + } + + return strings.TrimSpace(strings.Join(lines, "\n")) +} + +func looksLikeNonPOSIXShellError(output string) bool { + if strings.TrimSpace(output) == "" { + return false + } + + lower := strings.ToLower(output) + indicators := []string{ + "is not recognized as an internal or external command", + "the term ", + "command not found", + "unknown command", + "not found", + "not recognized", + } + for _, indicator := range indicators { + if strings.Contains(lower, indicator) { + return true + } + } + return false +} + +func (s *StarShell) GetUid() string { + res, _, _ := s.ShellWait("id -u") + return strings.TrimSpace(res) +} + +func (s *StarShell) GetGid() string { + res, _, _ := s.ShellWait("id -g") + return strings.TrimSpace(res) +} + +func (s *StarShell) GetUser() string { + res, _, _ := s.ShellWait("id -un") + return strings.TrimSpace(res) +} + +func (s *StarShell) GetGroup() string { + res, _, _ := s.ShellWait("id -gn") + return strings.TrimSpace(res) +} diff --git a/ssh.go b/ssh.go deleted file mode 100644 index e2f07e4..0000000 --- a/ssh.go +++ /dev/null @@ -1,636 +0,0 @@ -package starssh - -import ( - "bufio" - "encoding/base64" - "errors" - "fmt" - "golang.org/x/crypto/ssh" - "io" - "io/ioutil" - "net" - "regexp" - "strings" - "sync" - "time" -) - -type StarSSH struct { - Client *ssh.Client - PublicKey ssh.PublicKey - PubkeyBase64 string - Hostname string - RemoteAddr net.Addr - Banner string - LoginInfo LoginInput - online bool -} - -type LoginInput struct { - KeyExchanges []string - Ciphers []string - MACs []string - User string - Password string - Prikey string - PrikeyPwd string - Addr string - Port int - Timeout time.Duration - HostKeyCallback func(string, net.Addr, ssh.PublicKey) error - BannerCallback func(string) error -} -type StarShell struct { - Keyword string - UseWaitDefault bool - Session *ssh.Session - in io.Writer - out *bufio.Reader - er *bufio.Reader - outbyte []byte - errbyte []byte - lastout int64 - errors error - isprint bool - isfuncs bool - iscolor bool - isecho bool - rw sync.RWMutex - funcs func(string) -} - -func Login(info LoginInput) (*StarSSH, error) { - var ( - auth []ssh.AuthMethod - clientConfig *ssh.ClientConfig - config ssh.Config - err error - ) - sshInfo := new(StarSSH) - // get auth method - auth = make([]ssh.AuthMethod, 0) - if info.Prikey == "" { - keyboardInteractiveChallenge := func( - user, - instruction string, - questions []string, - echos []bool, - ) (answers []string, err error) { - if len(questions) == 0 { - return []string{}, nil - } - return []string{info.Password}, nil - } - auth = append(auth, ssh.Password(info.Password)) - auth = append(auth, ssh.KeyboardInteractive(keyboardInteractiveChallenge)) - } else { - pemBytes := []byte(info.Prikey) - var signer ssh.Signer - if info.PrikeyPwd == "" { - signer, err = ssh.ParsePrivateKey(pemBytes) - } else { - signer, err = ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(info.PrikeyPwd)) - } - if err != nil { - return nil, err - } - auth = append(auth, ssh.PublicKeys(signer)) - } - - if len(info.Ciphers) == 0 { - config = ssh.Config{ - Ciphers: []string{"aes128-ctr", "aes192-ctr", "aes256-ctr", "aes128-gcm@openssh.com", "arcfour256", "arcfour128", "aes128-cbc", "3des-cbc", "aes192-cbc", "aes256-cbc", "chacha20-poly1305@openssh.com"}, - } - } else { - config = ssh.Config{ - Ciphers: info.Ciphers, - } - } - - if len(info.MACs) != 0 { - config.MACs = info.MACs - } - if len(info.KeyExchanges) != 0 { - config.KeyExchanges = info.KeyExchanges - } - - if info.Timeout == 0 { - info.Timeout = time.Second * 5 - } - hostKeycbfunc := func(hostname string, remote net.Addr, key ssh.PublicKey) error { - - sshInfo.PublicKey = key - sshInfo.RemoteAddr = remote - sshInfo.Hostname = hostname - if info.HostKeyCallback != nil { - return info.HostKeyCallback(hostname, remote, key) - } - return nil - } - - bannercbfunc := func(banner string) error { - sshInfo.Banner = banner - if info.BannerCallback != nil { - return info.BannerCallback(banner) - } - return nil - } - - clientConfig = &ssh.ClientConfig{ - User: info.User, - Auth: auth, - Timeout: info.Timeout, - Config: config, - HostKeyCallback: hostKeycbfunc, - BannerCallback: bannercbfunc, - } - - // connet to ssh - - sshInfo.LoginInfo = info - sshInfo.Client, err = ssh.Dial("tcp", fmt.Sprintf("%s:%d", info.Addr, info.Port), clientConfig) - if err == nil && sshInfo.PublicKey != nil { - sshInfo.online = true - sshInfo.PubkeyBase64 = base64.StdEncoding.EncodeToString(sshInfo.PublicKey.Marshal()) - } - return sshInfo, err -} - -func LoginSimple(host string, user string, passwd string, prikeyPath string, port int, timeout time.Duration) (*StarSSH, error) { - var info = LoginInput{ - Addr: host, - Port: port, - Timeout: timeout, - User: user, - } - if prikeyPath != "" { - prikey, err := ioutil.ReadFile(prikeyPath) - if err != nil { - return nil, err - } - info.Prikey = string(prikey) - if passwd != "" { - info.PrikeyPwd = passwd - } - } else { - info.Password = passwd - } - return Login(info) -} - -func (s *StarShell) ShellWait(cmd string) (string, string, error) { - var outc, errc string = " ", " " - s.Clear() - defer s.Clear() - echo := "echo b7Y85R56TUY6R5UTb612" - err := s.WriteCommand(cmd) - if err != nil { - return "", "", err - } - time.Sleep(time.Millisecond * 20) - err = s.WriteCommand(echo) - if err != nil { - return "", "", err - } - for { - time.Sleep(time.Millisecond * 120) - outs := string(s.outbyte) - errs := string(s.errbyte) - outs = strings.TrimSpace(strings.ReplaceAll(outs, "\r\n", "\n")) - errs = strings.TrimSpace(strings.ReplaceAll(errs, "\r\n", "\n")) - if len(outs) >= len(cmd+"\n"+echo) && outs[0:len(cmd+"\n"+echo)] == cmd+"\n"+echo { - outs = outs[len(cmd+"\n"+echo):] - } else if len(outs) >= len(cmd) && outs[0:len(cmd)] == cmd { - outs = outs[len(cmd):] - } - if len(errs) >= len(cmd) && errs[0:len(cmd)] == cmd { - errs = errs[len(cmd):] - } - if s.UseWaitDefault { - if strings.Index(string(outs), "b7Y85R56TUY6R5UTb612") >= 0 { - list := strings.Split(string(outs), "\n") - for _, v := range list { - if strings.Index(v, "b7Y85R56TUY6R5UTb612") < 0 { - outc += v + "\n" - } - } - break - } - if strings.Index(string(errs), "b7Y85R56TUY6R5UTb612") >= 0 { - list := strings.Split(string(errs), "\n") - for _, v := range list { - if strings.Index(v, "b7Y85R56TUY6R5UTb612") < 0 { - errc += v + "\n" - } - } - break - } - } - if s.Keyword != "" { - if strings.Index(string(outs), s.Keyword) >= 0 { - list := strings.Split(string(outs), "\n") - for _, v := range list { - if strings.Index(v, s.Keyword) < 0 && strings.Index(v, "b7Y85R56TUY6R5UTb612") < 0 { - outc += v + "\n" - } - } - break - } - if strings.Index(string(errs), s.Keyword) >= 0 { - list := strings.Split(string(errs), "\n") - for _, v := range list { - if strings.Index(v, s.Keyword) < 0 && strings.Index(v, "b7Y85R56TUY6R5UTb612") < 0 { - errc += v + "\n" - } - } - break - } - } - } - return s.TrimColor(strings.TrimSpace(outc)), s.TrimColor(strings.TrimSpace(errc)), err -} - -func (s *StarShell) Close() error { - return s.Session.Close() -} - -func (s *StarShell) SwitchNoColor(is bool) { - s.iscolor = is -} - -func (s *StarShell) SwitchEcho(is bool) { - s.isecho = is -} - -func (s *StarShell) TrimColor(str string) string { - if s.iscolor { - return SedColor(str) - } - return str -} - -/* -本函数控制是否在本地屏幕上打印远程Shell的输出内容[true|false] -*/ -func (s *StarShell) SwitchPrint(run bool) { - s.isprint = run -} - -/* -本函数控制是否立即处理远程Shell输出每一行内容[true|false] -*/ -func (s *StarShell) SwitchFunc(run bool) { - s.isfuncs = run -} - -func (s *StarShell) SetFunc(funcs func(string)) { - s.funcs = funcs -} - -func (s *StarShell) Clear() { - defer s.rw.Unlock() - s.rw.Lock() - s.outbyte = []byte{} - s.errbyte = []byte{} - time.Sleep(time.Millisecond * 15) -} - -func (s *StarShell) ShellClear(cmd string, sleep int) (string, string, error) { - defer s.Clear() - s.Clear() - return s.Shell(cmd, sleep) -} - -func (s *StarShell) Shell(cmd string, sleep int) (string, string, error) { - if err := s.WriteCommand(cmd); err != nil { - return "", "", err - } - tmp1, tmp2, err := s.GetResult(sleep) - tmps := s.TrimColor(strings.TrimSpace(string(tmp1))) - if s.isecho { - n := len(strings.Split(cmd, "\n")) - if n == 1 { - list := strings.SplitN(tmps, "\n", 2) - if len(list) == 2 { - tmps = list[1] - } - } else { - list := strings.Split(tmps, "\n") - cmds := strings.Split(cmd, "\n") - for _, v := range cmds { - for k, v2 := range list { - if strings.TrimSpace(v2) == strings.TrimSpace(v) { - list[k] = "" - break - } - } - } - tmps = "" - for _, v := range list { - if v != "" { - tmps += v + "\n" - } - } - tmps = tmps[0 : len(tmps)-1] - } - } - return tmps, s.TrimColor(strings.TrimSpace(string(tmp2))), err -} - -func (s *StarShell) GetResult(sleep int) ([]byte, []byte, error) { - if s.errors != nil { - s.Session.Close() - return s.outbyte, s.errbyte, s.errors - } - if sleep > 0 { - time.Sleep(time.Millisecond * time.Duration(sleep)) - } - return s.outbyte, s.errbyte, nil -} - -func (s *StarShell) WriteCommand(cmd string) error { - return s.Write([]byte(cmd + "\n")) -} - -func (s *StarShell) Write(bstr []byte) error { - if s.errors != nil { - s.Session.Close() - return s.errors - } - _, err := s.in.Write(bstr) - return err -} - -func (s *StarShell) gohub() { - go func() { - var cache []byte - for { - read, err := s.er.ReadByte() - if err != nil { - s.errors = err - return - } - s.errbyte = append(s.errbyte, read) - if s.isprint { - fmt.Print(string([]byte{read})) - } - cache = append(cache, read) - if read == '\n' { - if s.isfuncs { - go s.funcs(s.TrimColor(strings.TrimSpace(string(cache)))) - cache = []byte{} - } - } - } - }() - var cache []byte - for { - read, err := s.out.ReadByte() - if err != nil { - s.errors = err - return - } - s.rw.Lock() - s.outbyte = append(s.outbyte, read) - cache = append(cache, read) - s.rw.Unlock() - if read == '\n' { - if s.isfuncs { - go s.funcs(strings.TrimSpace(string(cache))) - cache = []byte{} - } - } - if s.isprint { - fmt.Print(string([]byte{read})) - } - } -} - -func (s *StarShell) GetUid() string { - res, _, _ := s.ShellWait(`id | grep -oP "(?<=uid\=)\d+"`) - return strings.TrimSpace(res) -} -func (s *StarShell) GetGid() string { - res, _, _ := s.ShellWait(`id | grep -oP "(?<=gid\=)\d+"`) - return strings.TrimSpace(res) -} -func (s *StarShell) GetUser() string { - res, _, _ := s.ShellWait(`id | grep -oP "(?<=\().*?(?=\))" | head -n 1`) - return strings.TrimSpace(res) -} -func (s *StarShell) GetGroup() string { - res, _, _ := s.ShellWait(`id | grep -oP "(?<=\().*?(?=\))" | head -n 2 | tail -n 1`) - return strings.TrimSpace(res) -} - -func (s *StarSSH) NewShell() (shell *StarShell, err error) { - shell = new(StarShell) - shell.Session, err = s.NewSession() - if err != nil { - return - } - shell.in, _ = shell.Session.StdinPipe() - tmp, _ := shell.Session.StdoutPipe() - shell.out = bufio.NewReader(tmp) - tmp, _ = shell.Session.StderrPipe() - shell.er = bufio.NewReader(tmp) - err = shell.Session.Shell() - shell.isecho = true - go shell.Session.Wait() - shell.UseWaitDefault = true - shell.WriteCommand("bash") - time.Sleep(500 * time.Millisecond) - shell.WriteCommand("export PS1= ") - shell.WriteCommand("export PS2= ") - go shell.gohub() - time.Sleep(500 * time.Millisecond) - shell.Clear() - shell.Clear() - return -} - -func (s *StarSSH) Close() error { - if s.online { - return s.Client.Close() - } - return nil -} - -func (s *StarSSH) NewSession() (*ssh.Session, error) { - return NewSession(s.Client) -} - -func (s *StarSSH) ShellOne(cmd string) (string, error) { - newsess, err := s.NewSession() - if err != nil { - return "", err - } - data, err := newsess.CombinedOutput(cmd) - newsess.Close() - return strings.TrimSpace(string(data)), err -} - -func (s *StarSSH) Exists(filepath string) bool { - res, _ := s.ShellOne(`echo 1 && [ ! -e "` + filepath + `" ] && echo 2`) - if res == "1" { - return true - } else { - return false - } -} - -func (s *StarSSH) GetUid() string { - res, _ := s.ShellOne(`id | grep -oP "(?<=uid\=)\d+"`) - return strings.TrimSpace(res) -} -func (s *StarSSH) GetGid() string { - res, _ := s.ShellOne(`id | grep -oP "(?<=gid\=)\d+"`) - return strings.TrimSpace(res) -} -func (s *StarSSH) GetUser() string { - res, _ := s.ShellOne(`id | grep -oP "(?<=\().*?(?=\))" | head -n 1`) - return strings.TrimSpace(res) -} -func (s *StarSSH) GetGroup() string { - res, _ := s.ShellOne(`id | grep -oP "(?<=\().*?(?=\))" | head -n 2 | tail -n 1`) - return strings.TrimSpace(res) -} - -func (s *StarSSH) IsFile(filepath string) bool { - res, _ := s.ShellOne(`echo 1 && [ ! -f "` + filepath + `" ] && echo 2`) - if res == "1" { - return true - } else { - return false - } -} - -func (s *StarSSH) IsFolder(filepath string) bool { - res, _ := s.ShellOne(`echo 1 && [ ! -d "` + filepath + `" ] && echo 2`) - if res == "1" { - return true - } else { - return false - } -} - -func (s *StarSSH) ShellOneShowScreen(cmd string) (string, error) { - newsess, err := s.NewSession() - if err != nil { - return "", err - } - var bytes, errbytes []byte - tmp, _ := newsess.StdoutPipe() - reader := bufio.NewReader(tmp) - tmp, _ = newsess.StderrPipe() - errder := bufio.NewReader(tmp) - err = newsess.Start(cmd) - if err != nil { - return "", err - } - c := make(chan int, 1) - go newsess.Wait() - go func() { - for { - byt, err := reader.ReadByte() - if err != nil { - break - } - fmt.Print(string([]byte{byt})) - bytes = append(bytes, byt) - } - c <- 1 - }() - for { - byt, err := errder.ReadByte() - if err != nil { - break - } - fmt.Print(string([]byte{byt})) - errbytes = append(errbytes, byt) - } - _ = <-c - newsess.Close() - if len(errbytes) != 0 { - err = errors.New(strings.TrimSpace(string(errbytes))) - } else { - err = nil - } - return strings.TrimSpace(string(bytes)), err -} - -func (s *StarSSH) ShellOneToFunc(cmd string, callback func(string)) (string, error) { - newsess, err := s.NewSession() - if err != nil { - return "", err - } - var bytes, errbytes []byte - tmp, _ := newsess.StdoutPipe() - reader := bufio.NewReader(tmp) - tmp, _ = newsess.StderrPipe() - errder := bufio.NewReader(tmp) - err = newsess.Start(cmd) - if err != nil { - return "", err - } - c := make(chan int, 1) - go newsess.Wait() - go func() { - for { - byt, err := reader.ReadByte() - if err != nil { - break - } - callback(string([]byte{byt})) - bytes = append(bytes, byt) - } - c <- 1 - }() - for { - byt, err := errder.ReadByte() - if err != nil { - break - } - callback(string([]byte{byt})) - errbytes = append(errbytes, byt) - } - _ = <-c - newsess.Close() - if len(errbytes) != 0 { - err = errors.New(strings.TrimSpace(string(errbytes))) - } else { - err = nil - } - return strings.TrimSpace(string(bytes)), err -} - -func NewTransferSession(client *ssh.Client) (*ssh.Session, error) { - session, err := client.NewSession() - return session, err -} - -func NewSession(client *ssh.Client) (*ssh.Session, error) { - var session *ssh.Session - var err error - // create session - if session, err = client.NewSession(); err != nil { - return nil, err - } - modes := ssh.TerminalModes{ - ssh.ECHO: 1, // 还是要强制开启 - //ssh.IGNCR: 0, - ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud - ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud - } - - if err := session.RequestPty("xterm", 500, 250, modes); err != nil { - return nil, err - } - return session, nil -} - -func SedColor(str string) string { - reg := regexp.MustCompile(`\x1B\[([0-9]{1,2}(;[0-9]{1,2})?)?[m|K]`) - //fmt.Println("regexp:", reg.Match([]byte(str))) - return string(reg.ReplaceAll([]byte(str), []byte(""))) -} diff --git a/sshagent_unix.go b/sshagent_unix.go new file mode 100644 index 0000000..e6c6fb6 --- /dev/null +++ b/sshagent_unix.go @@ -0,0 +1,21 @@ +//go:build !windows + +package starssh + +import ( + "net" + "os" + "strings" + "time" +) + +func dialSSHAgent(timeout time.Duration) (net.Conn, error) { + agentSock := strings.TrimSpace(os.Getenv("SSH_AUTH_SOCK")) + if agentSock == "" { + return nil, errSSHAgentUnavailable + } + if timeout > 0 { + return net.DialTimeout("unix", agentSock, timeout) + } + return net.Dial("unix", agentSock) +} diff --git a/sshagent_windows.go b/sshagent_windows.go new file mode 100644 index 0000000..f6b25b5 --- /dev/null +++ b/sshagent_windows.go @@ -0,0 +1,70 @@ +//go:build windows + +package starssh + +import ( + "context" + "errors" + "net" + "os" + "strings" + "time" + + "github.com/Microsoft/go-winio" + "golang.org/x/sys/windows" +) + +const defaultWindowsSSHAgentPipe = `\\.\pipe\openssh-ssh-agent` + +func dialSSHAgent(timeout time.Duration) (net.Conn, error) { + agentSock := strings.TrimSpace(os.Getenv("SSH_AUTH_SOCK")) + if agentSock != "" { + return dialWindowsSSHAgentEndpoint(agentSock, timeout) + } + return dialWindowsNamedPipe(defaultWindowsSSHAgentPipe, timeout, true) +} + +func dialWindowsSSHAgentEndpoint(endpoint string, timeout time.Duration) (net.Conn, error) { + if pipePath, ok := normalizeWindowsSSHAgentPipe(endpoint); ok { + return dialWindowsNamedPipe(pipePath, timeout, false) + } + if timeout > 0 { + return net.DialTimeout("unix", endpoint, timeout) + } + return net.Dial("unix", endpoint) +} + +func dialWindowsNamedPipe(path string, timeout time.Duration, unavailableOnNotFound bool) (net.Conn, error) { + ctx := context.Background() + cancel := func() {} + if timeout > 0 { + ctx, cancel = context.WithTimeout(ctx, timeout) + } + defer cancel() + + conn, err := winio.DialPipeContext(ctx, path) + if err != nil && unavailableOnNotFound && isWindowsPipeUnavailable(err) { + return nil, errSSHAgentUnavailable + } + return conn, err +} + +func normalizeWindowsSSHAgentPipe(endpoint string) (string, bool) { + trimmed := strings.TrimSpace(endpoint) + if trimmed == "" { + return "", false + } + + normalized := trimmed + if strings.HasPrefix(normalized, "//./pipe/") { + normalized = `\\.\pipe\` + strings.TrimPrefix(normalized, "//./pipe/") + } + if strings.HasPrefix(normalized, `\\.\pipe\`) { + return normalized, true + } + return "", false +} + +func isWindowsPipeUnavailable(err error) bool { + return errors.Is(err, windows.ERROR_FILE_NOT_FOUND) || errors.Is(err, windows.ERROR_PATH_NOT_FOUND) +} diff --git a/state.go b/state.go new file mode 100644 index 0000000..c05afd4 --- /dev/null +++ b/state.go @@ -0,0 +1,106 @@ +package starssh + +import ( + "errors" + + "golang.org/x/crypto/ssh" +) + +type sshClientRequester interface { + SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) + Close() error +} + +var closeSSHClient = func(client sshClientRequester) error { + if client == nil { + return nil + } + return client.Close() +} + +func (s *StarSSH) snapshotSSHClient() *ssh.Client { + if s == nil { + return nil + } + + s.stateMu.RLock() + defer s.stateMu.RUnlock() + return s.Client +} + +func (s *StarSSH) requireSSHClient() (*ssh.Client, error) { + client := s.snapshotSSHClient() + if client == nil { + return nil, errors.New("ssh client is nil") + } + return client, nil +} + +func (s *StarSSH) setTransport(client *ssh.Client, upstream *StarSSH) { + if s == nil { + return + } + + s.stateMu.Lock() + defer s.stateMu.Unlock() + s.Client = client + s.upstream = upstream + s.online = client != nil +} + +func (s *StarSSH) detachTransport() (*ssh.Client, *StarSSH) { + if s == nil { + return nil, nil + } + + s.stateMu.Lock() + defer s.stateMu.Unlock() + + client := s.Client + upstream := s.upstream + s.Client = nil + s.upstream = nil + s.online = false + return client, upstream +} + +func (s *StarSSH) takeKeepaliveHandles() (chan struct{}, chan struct{}) { + if s == nil { + return nil, nil + } + + s.keepaliveMu.Lock() + defer s.keepaliveMu.Unlock() + + stop := s.keepaliveStop + done := s.keepaliveDone + s.keepaliveStop = nil + s.keepaliveDone = nil + return stop, done +} + +func (s *StarSSH) closeTransport(waitKeepalive bool) error { + if s == nil { + return nil + } + + _ = s.closeReusableSFTPClient() + + client, upstream := s.detachTransport() + stop, done := s.takeKeepaliveHandles() + if stop != nil { + close(stop) + } + + var closeErr error + if client != nil { + closeErr = normalizeAlreadyClosedError(closeSSHClient(client)) + } + if waitKeepalive && done != nil { + <-done + } + if upstreamErr := closeUpstream(upstream); closeErr == nil { + closeErr = upstreamErr + } + return closeErr +} diff --git a/terminal.go b/terminal.go new file mode 100644 index 0000000..64b7de1 --- /dev/null +++ b/terminal.go @@ -0,0 +1,431 @@ +package starssh + +import ( + "context" + "errors" + "io" + "sync" + + "golang.org/x/crypto/ssh" +) + +func (s *StarSSH) NewTerminal(config *TerminalConfig) (*TerminalSession, error) { + session, err := s.NewPTYSession(config) + if err != nil { + return nil, err + } + + stdin, err := session.StdinPipe() + if err != nil { + _ = session.Close() + return nil, err + } + + stdout, err := session.StdoutPipe() + if err != nil { + _ = session.Close() + return nil, err + } + + stderr, err := session.StderrPipe() + if err != nil { + _ = session.Close() + return nil, err + } + + if err := session.Shell(); err != nil { + _ = session.Close() + return nil, err + } + + return &TerminalSession{ + Session: session, + stdin: stdin, + stdout: stdout, + stderr: stderr, + runDone: make(chan struct{}), + waitDone: make(chan struct{}), + }, nil +} + +func (t *TerminalSession) AttachIO(stdin io.Reader, stdout io.Writer, stderr io.Writer) { + if t == nil { + return + } + + t.attachMu.Lock() + defer t.attachMu.Unlock() + t.in = stdin + t.out = stdout + t.errOut = stderr +} + +func (t *TerminalSession) StdinWriter() io.Writer { + if t == nil { + return nil + } + return t.stdin +} + +func (t *TerminalSession) StdoutReader() io.Reader { + if t == nil { + return nil + } + return t.stdout +} + +func (t *TerminalSession) StderrReader() io.Reader { + if t == nil { + return nil + } + return t.stderr +} + +func (t *TerminalSession) Write(data []byte) (int, error) { + if t == nil || t.stdin == nil { + return 0, errors.New("terminal stdin is not initialized") + } + return t.stdin.Write(data) +} + +func (t *TerminalSession) SendControl(control TerminalControl) error { + _, err := t.Write([]byte{byte(control)}) + return err +} + +func (t *TerminalSession) Interrupt() error { + return t.SendControl(TerminalControlInterrupt) +} + +func (t *TerminalSession) Signal(sig ssh.Signal) error { + if t == nil || t.Session == nil { + return errors.New("terminal session is not initialized") + } + if sig == "" { + return errors.New("signal is empty") + } + return t.Session.Signal(sig) +} + +func (t *TerminalSession) Run(ctx context.Context) error { + if t == nil { + return errors.New("terminal session is nil") + } + if ctx == nil { + ctx = context.Background() + } + + t.runOnce.Do(func() { + t.runErr = t.run(ctx) + close(t.runDone) + }) + + <-t.runDone + return t.runErr +} + +func (t *TerminalSession) run(ctx context.Context) error { + if t.Session == nil { + return errors.New("terminal session is not initialized") + } + + t.attachMu.RLock() + in := t.in + out := t.out + errOut := t.errOut + t.attachMu.RUnlock() + if out == nil { + out = io.Discard + } + if errOut == nil { + errOut = out + } + + inputReader, cancelInput, inputCancelable, err := prepareTerminalInputReader(in) + if err != nil { + return err + } + defer cancelInput() + + var copyWG sync.WaitGroup + doneCopy := make(chan struct{}) + copyWG.Add(2) + go func() { + defer copyWG.Done() + if t.stdout != nil { + _, _ = io.Copy(out, t.stdout) + } + }() + go func() { + defer copyWG.Done() + if t.stderr != nil { + _, _ = io.Copy(errOut, t.stderr) + } + }() + go func() { + copyWG.Wait() + close(doneCopy) + }() + + var doneInput chan struct{} + if inputReader != nil && t.stdin != nil { + doneInput = make(chan struct{}) + go func() { + defer close(doneInput) + _, _ = io.Copy(t.stdin, inputReader) + _ = t.stdin.Close() + }() + } + + waitInputPump := func() { + if doneInput == nil { + return + } + select { + case <-doneInput: + return + default: + } + + if inputCancelable { + cancelInput() + <-doneInput + } + } + + type waitResult struct { + info TerminalExitInfo + err error + } + + waitCh := make(chan waitResult, 1) + go func() { + info, err := t.WaitResult() + waitCh <- waitResult{info: info, err: err} + }() + + select { + case result := <-waitCh: + waitInputPump() + <-doneCopy + if result.err != nil { + return result.err + } + return result.info.CommandError() + case <-ctx.Done(): + t.markCloseReason(terminalCloseReasonFromErr(ctx.Err()), ctx.Err()) + cancelInput() + _ = t.Close() + <-waitCh + waitInputPump() + <-doneCopy + return ctx.Err() + } +} + +func (t *TerminalSession) Wait() error { + info, err := t.WaitResult() + if err != nil { + return err + } + return info.CommandError() +} + +func (t *TerminalSession) WaitResult() (TerminalExitInfo, error) { + waitErr := t.waitRaw() + info, closeErr := t.snapshotExitState() + if closeErr != nil { + return info, closeErr + } + if waitErr == nil { + return info, nil + } + + var exitErr *ssh.ExitError + if errors.As(waitErr, &exitErr) { + return info, nil + } + if normalizeAlreadyClosedError(waitErr) == nil || info.Reason == TerminalCloseReasonClosed { + return info, nil + } + return info, waitErr +} + +func (t *TerminalSession) ExitInfo() TerminalExitInfo { + if t == nil { + return TerminalExitInfo{} + } + + t.stateMu.RLock() + defer t.stateMu.RUnlock() + return t.exitInfo +} + +func (info TerminalExitInfo) Success() bool { + return info.Reason == TerminalCloseReasonExit && info.ExitCode == 0 && info.ExitSignal == "" +} + +func (info TerminalExitInfo) CommandError() error { + if info.Reason != TerminalCloseReasonExit && info.Reason != TerminalCloseReasonSignal { + return nil + } + if info.ExitCode == 0 && info.ExitSignal == "" { + return nil + } + return &ExecExitError{ + Status: info.ExitCode, + Signal: info.ExitSignal, + Message: info.ExitMessage, + } +} + +func (t *TerminalSession) Resize(columns int, rows int) error { + if t == nil || t.Session == nil { + return errors.New("terminal session is not initialized") + } + if columns <= 0 || rows <= 0 { + return errors.New("columns and rows must be > 0") + } + return t.Session.WindowChange(rows, columns) +} + +func (t *TerminalSession) Close() error { + if t == nil { + return nil + } + + var closeErr error + t.closeOnce.Do(func() { + if t.stdin != nil { + _ = t.stdin.Close() + } + if t.Session != nil { + closeErr = normalizeAlreadyClosedError(t.Session.Close()) + } + }) + if closeErr != nil { + t.markCloseReason(TerminalCloseReasonTransportError, closeErr) + } + return closeErr +} + +func (t *TerminalSession) waitRaw() error { + if t == nil || t.Session == nil { + return errors.New("terminal session is not initialized") + } + + t.waitOnce.Do(func() { + go func() { + waitErr := t.Session.Wait() + t.setWaitResult(waitErr) + close(t.waitDone) + }() + }) + + <-t.waitDone + + t.stateMu.RLock() + defer t.stateMu.RUnlock() + return t.waitErr +} + +func (t *TerminalSession) setWaitResult(waitErr error) { + if t == nil { + return + } + + t.stateMu.Lock() + defer t.stateMu.Unlock() + t.waitErr = waitErr + t.exitInfo = buildTerminalExitInfo(waitErr, t.closeReason) +} + +func (t *TerminalSession) markCloseReason(reason TerminalCloseReason, err error) { + if t == nil || reason == TerminalCloseReasonUnknown { + return + } + + t.stateMu.Lock() + defer t.stateMu.Unlock() + + if terminalCloseReasonPriority(reason) >= terminalCloseReasonPriority(t.closeReason) { + t.closeReason = reason + } + if err != nil && t.closeErr == nil { + t.closeErr = err + } +} + +func (t *TerminalSession) snapshotExitState() (TerminalExitInfo, error) { + if t == nil { + return TerminalExitInfo{}, nil + } + + t.stateMu.RLock() + defer t.stateMu.RUnlock() + return t.exitInfo, t.closeErr +} + +func buildTerminalExitInfo(waitErr error, overrideReason TerminalCloseReason) TerminalExitInfo { + info := TerminalExitInfo{} + + if waitErr == nil { + info.Reason = TerminalCloseReasonExit + } else { + var exitErr *ssh.ExitError + switch { + case errors.As(waitErr, &exitErr): + info.ExitCode = exitErr.ExitStatus() + info.ExitSignal = exitErr.Signal() + info.ExitMessage = exitErr.Msg() + if info.ExitSignal != "" { + info.Reason = TerminalCloseReasonSignal + } else { + info.Reason = TerminalCloseReasonExit + } + case normalizeAlreadyClosedError(waitErr) == nil: + info.Reason = TerminalCloseReasonClosed + case errors.Is(waitErr, context.Canceled): + info.Reason = TerminalCloseReasonContextCanceled + case errors.Is(waitErr, context.DeadlineExceeded): + info.Reason = TerminalCloseReasonDeadlineExceeded + default: + info.Reason = TerminalCloseReasonTransportError + } + } + + if overrideReason != TerminalCloseReasonUnknown && + terminalCloseReasonPriority(overrideReason) >= terminalCloseReasonPriority(info.Reason) { + info.Reason = overrideReason + } + + return info +} + +func terminalCloseReasonFromErr(err error) TerminalCloseReason { + switch { + case errors.Is(err, context.DeadlineExceeded): + return TerminalCloseReasonDeadlineExceeded + case errors.Is(err, context.Canceled): + return TerminalCloseReasonContextCanceled + case err != nil: + return TerminalCloseReasonTransportError + default: + return TerminalCloseReasonUnknown + } +} + +func terminalCloseReasonPriority(reason TerminalCloseReason) int { + switch reason { + case TerminalCloseReasonContextCanceled, TerminalCloseReasonDeadlineExceeded: + return 30 + case TerminalCloseReasonTransportError: + return 20 + case TerminalCloseReasonClosed: + return 10 + case TerminalCloseReasonSignal, TerminalCloseReasonExit: + return 5 + default: + return 0 + } +} diff --git a/terminal_input.go b/terminal_input.go new file mode 100644 index 0000000..d9951e3 --- /dev/null +++ b/terminal_input.go @@ -0,0 +1,225 @@ +package starssh + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "os" + "reflect" + "sync" + "unsafe" +) + +// TerminalInputSourceProvider lets wrapper readers expose a closer-friendly source reader. +// Implementations that buffer data should return a source that already includes any prefetched bytes. +type TerminalInputSourceProvider interface { + TerminalInputSource() io.Reader +} + +// TerminalInputCanceler lets wrapper readers expose an explicit cancellation hook. +// It is useful for line editors or custom buffered readers that cannot safely expose a raw io.ReadCloser. +type TerminalInputCanceler interface { + TerminalInputCancel() error +} + +// TerminalInputAdapter adapts wrapper readers into a cancelable terminal input source. +// Reader is what TerminalSession consumes, Source is the closer-friendly underlying reader when available. +type TerminalInputAdapter struct { + Reader io.Reader + Source io.Reader + Cancel func() error +} + +func (a TerminalInputAdapter) Read(p []byte) (int, error) { + if a.Reader == nil { + return 0, io.EOF + } + return a.Reader.Read(p) +} + +func (a TerminalInputAdapter) TerminalInputSource() io.Reader { + if a.Source != nil { + return a.Source + } + return a.Reader +} + +func (a TerminalInputAdapter) TerminalInputCancel() error { + if a.Cancel != nil { + return a.Cancel() + } + if closer, ok := a.Source.(io.Closer); ok && closer != nil { + return closer.Close() + } + if closer, ok := a.Reader.(io.Closer); ok && closer != nil { + return closer.Close() + } + return nil +} + +func prepareTerminalInputReader(in io.Reader) (io.Reader, func(), bool, error) { + if in == nil { + return nil, func() {}, false, nil + } + + var cancelOnce sync.Once + wrapCancel := func(fn func()) func() { + return func() { + cancelOnce.Do(fn) + } + } + + if provider, ok := in.(TerminalInputSourceProvider); ok { + source := provider.TerminalInputSource() + if source == nil || sameReader(source, in) { + return prepareDirectTerminalInputReader(in, wrapCancel) + } + + prepared, cancel, cancelable, err := prepareTerminalInputReader(source) + if err != nil { + return nil, nil, false, err + } + if canceler, ok := in.(TerminalInputCanceler); ok { + return prepared, wrapCancel(func() { + cancel() + _ = canceler.TerminalInputCancel() + }), true, nil + } + return prepared, cancel, cancelable, nil + } + + return prepareDirectTerminalInputReader(in, wrapCancel) +} + +func prepareDirectTerminalInputReader(in io.Reader, wrapCancel func(func()) func()) (io.Reader, func(), bool, error) { + if in == nil { + return nil, func() {}, false, nil + } + + switch typed := in.(type) { + case *bufio.Reader: + return prepareBufferedTerminalInputReader(typed) + case *bufio.ReadWriter: + if typed.Reader == nil { + return in, func() {}, false, nil + } + return prepareBufferedTerminalInputReader(typed.Reader) + } + + if canceler, ok := in.(TerminalInputCanceler); ok { + return in, wrapCancel(func() { + _ = canceler.TerminalInputCancel() + }), true, nil + } + + if file, ok := in.(*os.File); ok { + dup, err := duplicateTerminalInputFile(file) + if err != nil { + return nil, nil, false, fmt.Errorf("duplicate terminal input: %w", err) + } + return dup, wrapCancel(func() { + _ = dup.Close() + }), true, nil + } + + if closer, ok := in.(io.ReadCloser); ok { + return closer, wrapCancel(func() { + _ = closer.Close() + }), true, nil + } + + return in, func() {}, false, nil +} + +func prepareBufferedTerminalInputReader(reader *bufio.Reader) (io.Reader, func(), bool, error) { + if reader == nil { + return nil, func() {}, false, nil + } + + bufferedPrefix, err := snapshotBufferedPrefix(reader) + if err != nil { + return nil, nil, false, err + } + + underlying := unwrapBufioReader(reader) + if underlying == nil { + if len(bufferedPrefix) == 0 { + return reader, func() {}, false, nil + } + return io.MultiReader(bytes.NewReader(bufferedPrefix), reader), func() {}, false, nil + } + + prepared, cancel, cancelable, err := prepareTerminalInputReader(underlying) + if err != nil { + return nil, nil, false, err + } + if len(bufferedPrefix) == 0 { + return prepared, cancel, cancelable, nil + } + if prepared == nil { + return bytes.NewReader(bufferedPrefix), cancel, cancelable, nil + } + return io.MultiReader(bytes.NewReader(bufferedPrefix), prepared), cancel, cancelable, nil +} + +func snapshotBufferedPrefix(reader *bufio.Reader) ([]byte, error) { + if reader == nil { + return nil, nil + } + + buffered := reader.Buffered() + if buffered == 0 { + return nil, nil + } + + chunk, err := reader.Peek(buffered) + if err != nil && !errors.Is(err, io.EOF) { + return nil, fmt.Errorf("peek terminal input buffer: %w", err) + } + prefix := append([]byte(nil), chunk...) + if _, err := reader.Discard(len(prefix)); err != nil { + return nil, fmt.Errorf("discard terminal input buffer: %w", err) + } + return prefix, nil +} + +func unwrapBufioReader(reader *bufio.Reader) io.Reader { + if reader == nil { + return nil + } + + value := reflect.ValueOf(reader) + if value.Kind() != reflect.Pointer || value.IsNil() { + return nil + } + + field := value.Elem().FieldByName("rd") + if !field.IsValid() { + return nil + } + + underlyingValue := reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem() + underlying, ok := underlyingValue.Interface().(io.Reader) + if !ok || underlying == nil || sameReader(underlying, reader) { + return nil + } + return underlying +} + +func sameReader(left io.Reader, right io.Reader) bool { + if left == nil || right == nil { + return false + } + + leftValue := reflect.ValueOf(left) + rightValue := reflect.ValueOf(right) + if !leftValue.IsValid() || !rightValue.IsValid() { + return false + } + if leftValue.Kind() != reflect.Pointer || rightValue.Kind() != reflect.Pointer { + return false + } + return leftValue.Pointer() == rightValue.Pointer() +} diff --git a/terminal_input_adapter_test.go b/terminal_input_adapter_test.go new file mode 100644 index 0000000..a2e279c --- /dev/null +++ b/terminal_input_adapter_test.go @@ -0,0 +1,49 @@ +package starssh + +import ( + "bufio" + "io" + "testing" + "time" +) + +func TestPrepareTerminalInputReaderAdapterCancelUnblocksRead(t *testing.T) { + reader, writer := io.Pipe() + defer reader.Close() + defer writer.Close() + + adapter := TerminalInputAdapter{ + Reader: bufio.NewReader(reader), + Source: reader, + Cancel: func() error { + return reader.CloseWithError(io.ErrClosedPipe) + }, + } + + prepared, cancel, cancelable, err := prepareTerminalInputReader(adapter) + if err != nil { + t.Fatalf("prepare adapter input: %v", err) + } + if !cancelable { + t.Fatal("expected adapter-backed input to be cancelable") + } + + done := make(chan error, 1) + go func() { + buf := make([]byte, 1) + _, readErr := prepared.Read(buf) + done <- readErr + }() + + time.Sleep(50 * time.Millisecond) + cancel() + + select { + case readErr := <-done: + if readErr == nil { + t.Fatal("expected adapter cancel to interrupt blocking read") + } + case <-time.After(time.Second): + t.Fatal("blocking adapter input did not unblock after cancel") + } +} diff --git a/terminal_input_test.go b/terminal_input_test.go new file mode 100644 index 0000000..f3893b6 --- /dev/null +++ b/terminal_input_test.go @@ -0,0 +1,290 @@ +package starssh + +import ( + "bufio" + "bytes" + "io" + "os" + "testing" + "time" +) + +type terminalInputProvider struct { + io.Reader + source io.Reader +} + +func (p terminalInputProvider) TerminalInputSource() io.Reader { + return p.source +} + +type prefixedReadCloser struct { + io.Reader + io.Closer +} + +func TestPrepareTerminalInputReaderBufioReaderPreservesBufferedBytes(t *testing.T) { + reader, writer, err := os.Pipe() + if err != nil { + t.Fatalf("create pipe: %v", err) + } + defer reader.Close() + + if _, err := writer.Write([]byte("hello world")); err != nil { + writer.Close() + t.Fatalf("write pipe: %v", err) + } + if err := writer.Close(); err != nil { + t.Fatalf("close writer: %v", err) + } + + buffered := bufio.NewReaderSize(reader, 4) + peeked, err := buffered.Peek(5) + if err != nil { + t.Fatalf("prime buffer: %v", err) + } + if string(peeked) != "hello" { + t.Fatalf("unexpected buffered prefix: %q", string(peeked)) + } + + prepared, cancel, cancelable, err := prepareTerminalInputReader(buffered) + if err != nil { + t.Fatalf("prepare input: %v", err) + } + defer cancel() + + if !cancelable { + t.Fatal("expected cancelable reader") + } + + data, err := io.ReadAll(prepared) + if err != nil { + t.Fatalf("read prepared input: %v", err) + } + if string(data) != "hello world" { + t.Fatalf("unexpected prepared input: %q", string(data)) + } +} + +func TestPrepareTerminalInputReaderBufioReadWriterPreservesBufferedBytes(t *testing.T) { + reader, writer, err := os.Pipe() + if err != nil { + t.Fatalf("create pipe: %v", err) + } + defer reader.Close() + + if _, err := writer.Write([]byte("buffered payload")); err != nil { + writer.Close() + t.Fatalf("write pipe: %v", err) + } + if err := writer.Close(); err != nil { + t.Fatalf("close writer: %v", err) + } + + readWriter := bufio.NewReadWriter(bufio.NewReaderSize(reader, 8), bufio.NewWriter(io.Discard)) + if _, err := readWriter.Reader.Peek(8); err != nil { + t.Fatalf("prime readwriter buffer: %v", err) + } + + prepared, cancel, cancelable, err := prepareTerminalInputReader(readWriter) + if err != nil { + t.Fatalf("prepare readwriter input: %v", err) + } + defer cancel() + + if !cancelable { + t.Fatal("expected cancelable readwriter input") + } + + data, err := io.ReadAll(prepared) + if err != nil { + t.Fatalf("read prepared readwriter input: %v", err) + } + if string(data) != "buffered payload" { + t.Fatalf("unexpected prepared readwriter input: %q", string(data)) + } +} + +func TestPrepareTerminalInputReaderBufioReaderFallbackKeepsData(t *testing.T) { + buffered := bufio.NewReader(bytes.NewBufferString("abc123")) + if _, err := buffered.Peek(3); err != nil { + t.Fatalf("prime buffer: %v", err) + } + + prepared, cancel, cancelable, err := prepareTerminalInputReader(buffered) + if err != nil { + t.Fatalf("prepare fallback input: %v", err) + } + defer cancel() + + if cancelable { + t.Fatal("expected non-cancelable reader") + } + + data, err := io.ReadAll(prepared) + if err != nil { + t.Fatalf("read fallback input: %v", err) + } + if string(data) != "abc123" { + t.Fatalf("unexpected fallback input: %q", string(data)) + } +} + +func TestPrepareTerminalInputReaderProviderPrefersExplicitSource(t *testing.T) { + reader, writer, err := os.Pipe() + if err != nil { + t.Fatalf("create pipe: %v", err) + } + defer reader.Close() + + if _, err := writer.Write([]byte("provider data")); err != nil { + writer.Close() + t.Fatalf("write pipe: %v", err) + } + if err := writer.Close(); err != nil { + t.Fatalf("close writer: %v", err) + } + + buffered := bufio.NewReader(reader) + prefix, err := buffered.Peek(len("provider data")) + if err != nil { + t.Fatalf("prime buffer: %v", err) + } + + source := prefixedReadCloser{ + Reader: io.MultiReader(bytes.NewReader(append([]byte(nil), prefix...)), reader), + Closer: reader, + } + provider := terminalInputProvider{ + Reader: buffered, + source: source, + } + + prepared, cancel, cancelable, err := prepareTerminalInputReader(provider) + if err != nil { + t.Fatalf("prepare provider input: %v", err) + } + defer cancel() + + if !cancelable { + t.Fatal("expected provider-backed input to be cancelable") + } + + data, err := io.ReadAll(prepared) + if err != nil { + t.Fatalf("read provider input: %v", err) + } + if string(data) != "provider data" { + t.Fatalf("unexpected provider input: %q", string(data)) + } +} + +func TestPrepareTerminalInputReaderProviderCancelUnblocksRead(t *testing.T) { + reader, writer := io.Pipe() + defer reader.Close() + defer writer.Close() + + buffered := bufio.NewReader(reader) + provider := terminalInputProvider{ + Reader: buffered, + source: reader, + } + + prepared, cancel, cancelable, err := prepareTerminalInputReader(provider) + if err != nil { + t.Fatalf("prepare input: %v", err) + } + + if !cancelable { + t.Fatal("expected provider-backed input to be cancelable") + } + + done := make(chan error, 1) + go func() { + buf := make([]byte, 1) + _, readErr := prepared.Read(buf) + done <- readErr + }() + + time.Sleep(50 * time.Millisecond) + cancel() + + select { + case readErr := <-done: + if readErr == nil { + t.Fatal("expected cancel to interrupt blocking read") + } + case <-time.After(time.Second): + t.Fatal("blocking read did not unblock after cancel") + } +} + +func TestPrepareTerminalInputReaderBufioReaderCancelUnblocksRead(t *testing.T) { + reader, writer := io.Pipe() + defer reader.Close() + defer writer.Close() + + buffered := bufio.NewReader(reader) + prepared, cancel, cancelable, err := prepareTerminalInputReader(buffered) + if err != nil { + t.Fatalf("prepare input: %v", err) + } + + if !cancelable { + t.Fatal("expected bufio reader to be cancelable") + } + + done := make(chan error, 1) + go func() { + buf := make([]byte, 1) + _, readErr := prepared.Read(buf) + done <- readErr + }() + + time.Sleep(50 * time.Millisecond) + cancel() + + select { + case readErr := <-done: + if readErr == nil { + t.Fatal("expected cancel to interrupt blocking read") + } + case <-time.After(time.Second): + t.Fatal("blocking bufio reader did not unblock after cancel") + } +} + +func TestPrepareTerminalInputReaderBufioReadWriterCancelUnblocksRead(t *testing.T) { + reader, writer := io.Pipe() + defer reader.Close() + defer writer.Close() + + readWriter := bufio.NewReadWriter(bufio.NewReader(reader), bufio.NewWriter(io.Discard)) + prepared, cancel, cancelable, err := prepareTerminalInputReader(readWriter) + if err != nil { + t.Fatalf("prepare readwriter input: %v", err) + } + + if !cancelable { + t.Fatal("expected bufio readwriter to be cancelable") + } + + done := make(chan error, 1) + go func() { + buf := make([]byte, 1) + _, readErr := prepared.Read(buf) + done <- readErr + }() + + time.Sleep(50 * time.Millisecond) + cancel() + + select { + case readErr := <-done: + if readErr == nil { + t.Fatal("expected cancel to interrupt blocking read") + } + case <-time.After(time.Second): + t.Fatal("blocking readwriter input did not unblock after cancel") + } +} diff --git a/terminal_input_unix.go b/terminal_input_unix.go new file mode 100644 index 0000000..263af6d --- /dev/null +++ b/terminal_input_unix.go @@ -0,0 +1,21 @@ +//go:build !windows + +package starssh + +import ( + "os" + "syscall" +) + +func duplicateTerminalInputFile(file *os.File) (*os.File, error) { + if file == nil { + return nil, os.ErrInvalid + } + + fd, err := syscall.Dup(int(file.Fd())) + if err != nil { + return nil, err + } + syscall.CloseOnExec(fd) + return os.NewFile(uintptr(fd), file.Name()), nil +} diff --git a/terminal_input_windows.go b/terminal_input_windows.go new file mode 100644 index 0000000..539859b --- /dev/null +++ b/terminal_input_windows.go @@ -0,0 +1,32 @@ +//go:build windows + +package starssh + +import ( + "os" + + "golang.org/x/sys/windows" +) + +func duplicateTerminalInputFile(file *os.File) (*os.File, error) { + if file == nil { + return nil, os.ErrInvalid + } + + currentProcess := windows.CurrentProcess() + var duplicated windows.Handle + err := windows.DuplicateHandle( + currentProcess, + windows.Handle(file.Fd()), + currentProcess, + &duplicated, + 0, + false, + windows.DUPLICATE_SAME_ACCESS, + ) + if err != nil { + return nil, err + } + + return os.NewFile(uintptr(duplicated), file.Name()), nil +} diff --git a/transport.go b/transport.go new file mode 100644 index 0000000..92be42b --- /dev/null +++ b/transport.go @@ -0,0 +1,336 @@ +package starssh + +import ( + "bufio" + "context" + "encoding/base64" + "errors" + "fmt" + "io" + "net" + "net/http" + "strconv" + "strings" + "time" +) + +type bufferedConn struct { + net.Conn + reader *bufio.Reader +} + +func (c *bufferedConn) Read(p []byte) (int, error) { + if c == nil || c.reader == nil { + return 0, io.EOF + } + return c.reader.Read(p) +} + +func resolveDialContext(info LoginInput) DialContextFunc { + if info.DialContext != nil { + return info.DialContext + } + + dialer := &net.Dialer{ + Timeout: info.Timeout, + } + return dialer.DialContext +} + +func dialTargetConn(ctx context.Context, info LoginInput) (net.Conn, *StarSSH, error) { + targetAddr := joinHostPort(info.Addr, info.Port) + if info.Jump != nil { + return dialViaJump(ctx, info, targetAddr) + } + + dialContext := resolveDialContext(info) + proxyConfig := normalizeProxyConfig(info.Proxy, info.Timeout) + if proxyConfig != nil { + return dialViaProxy(ctx, dialContext, *proxyConfig, targetAddr) + } + + conn, err := dialContext(ctx, "tcp", targetAddr) + return conn, nil, err +} + +func dialViaJump(ctx context.Context, info LoginInput, targetAddr string) (net.Conn, *StarSSH, error) { + if info.Jump == nil { + return nil, nil, errors.New("jump login info is nil") + } + + jumpClient, err := loginWithContext(ctx, *info.Jump) + if err != nil { + return nil, nil, err + } + + conn, err := jumpClient.dialTCPContext(ctx, "tcp", targetAddr, jumpClient.Close) + if err != nil { + _ = jumpClient.Close() + return nil, nil, err + } + return conn, jumpClient, nil +} + +func dialViaProxy(ctx context.Context, dialContext DialContextFunc, proxy ProxyConfig, targetAddr string) (net.Conn, *StarSSH, error) { + if dialContext == nil { + return nil, nil, errors.New("dial context is nil") + } + if strings.TrimSpace(proxy.Addr) == "" { + return nil, nil, errors.New("proxy address is empty") + } + + switch proxy.Type { + case ProxyTypeSOCKS5: + conn, err := dialSOCKS5(ctx, dialContext, proxy, targetAddr) + return conn, nil, err + case ProxyTypeHTTPConnect: + conn, err := dialHTTPConnect(ctx, dialContext, proxy, targetAddr) + return conn, nil, err + default: + return nil, nil, fmt.Errorf("unsupported proxy type %q", proxy.Type) + } +} + +func dialHTTPConnect(ctx context.Context, dialContext DialContextFunc, proxy ProxyConfig, targetAddr string) (net.Conn, error) { + conn, err := dialContext(ctx, "tcp", proxy.Addr) + if err != nil { + return nil, err + } + + restoreDeadline := applyConnDeadline(conn, ctx, proxy.Timeout) + defer restoreDeadline() + + request := fmt.Sprintf("CONNECT %s HTTP/1.1\r\nHost: %s\r\n", targetAddr, targetAddr) + if proxy.Username != "" || proxy.Password != "" { + token := base64.StdEncoding.EncodeToString([]byte(proxy.Username + ":" + proxy.Password)) + request += "Proxy-Authorization: Basic " + token + "\r\n" + } + request += "\r\n" + + if _, err := io.WriteString(conn, request); err != nil { + _ = conn.Close() + return nil, err + } + + reader := bufio.NewReader(conn) + response, err := http.ReadResponse(reader, &http.Request{Method: http.MethodConnect}) + if err != nil { + _ = conn.Close() + return nil, err + } + defer response.Body.Close() + + if response.StatusCode < 200 || response.StatusCode >= 300 { + _, _ = io.Copy(io.Discard, io.LimitReader(response.Body, 1024)) + _ = conn.Close() + return nil, fmt.Errorf("http CONNECT proxy rejected target %s: %s", targetAddr, response.Status) + } + + if reader.Buffered() == 0 { + return conn, nil + } + return &bufferedConn{Conn: conn, reader: reader}, nil +} + +func dialSOCKS5(ctx context.Context, dialContext DialContextFunc, proxy ProxyConfig, targetAddr string) (net.Conn, error) { + conn, err := dialContext(ctx, "tcp", proxy.Addr) + if err != nil { + return nil, err + } + + restoreDeadline := applyConnDeadline(conn, ctx, proxy.Timeout) + defer restoreDeadline() + + methods := []byte{0x00} + useAuth := proxy.Username != "" || proxy.Password != "" + if useAuth { + methods = append(methods, 0x02) + } + + hello := append([]byte{0x05, byte(len(methods))}, methods...) + if _, err := conn.Write(hello); err != nil { + _ = conn.Close() + return nil, err + } + + response := make([]byte, 2) + if _, err := io.ReadFull(conn, response); err != nil { + _ = conn.Close() + return nil, err + } + if response[0] != 0x05 { + _ = conn.Close() + return nil, fmt.Errorf("invalid socks5 version %d", response[0]) + } + if response[1] == 0xFF { + _ = conn.Close() + return nil, errors.New("socks5 proxy has no acceptable auth method") + } + + if response[1] == 0x02 { + if err := writeSOCKS5UserPassAuth(conn, proxy.Username, proxy.Password); err != nil { + _ = conn.Close() + return nil, err + } + } + + if err := writeSOCKS5Connect(conn, targetAddr); err != nil { + _ = conn.Close() + return nil, err + } + + if err := readSOCKS5ConnectResponse(conn); err != nil { + _ = conn.Close() + return nil, err + } + + return conn, nil +} + +func writeSOCKS5UserPassAuth(conn net.Conn, username string, password string) error { + if len(username) > 255 || len(password) > 255 { + return errors.New("socks5 username/password too long") + } + + request := make([]byte, 0, 3+len(username)+len(password)) + request = append(request, 0x01, byte(len(username))) + request = append(request, []byte(username)...) + request = append(request, byte(len(password))) + request = append(request, []byte(password)...) + + if _, err := conn.Write(request); err != nil { + return err + } + + response := make([]byte, 2) + if _, err := io.ReadFull(conn, response); err != nil { + return err + } + if response[1] != 0x00 { + return errors.New("socks5 username/password authentication failed") + } + return nil +} + +func writeSOCKS5Connect(conn net.Conn, targetAddr string) error { + host, portString, err := net.SplitHostPort(targetAddr) + if err != nil { + return err + } + + port, err := strconv.Atoi(portString) + if err != nil { + return err + } + if port < 0 || port > 65535 { + return fmt.Errorf("invalid port %d", port) + } + + request := []byte{0x05, 0x01, 0x00} + if ip := net.ParseIP(host); ip != nil { + if ip4 := ip.To4(); ip4 != nil { + request = append(request, 0x01) + request = append(request, ip4...) + } else { + request = append(request, 0x04) + request = append(request, ip.To16()...) + } + } else { + if len(host) > 255 { + return errors.New("socks5 target host too long") + } + request = append(request, 0x03, byte(len(host))) + request = append(request, []byte(host)...) + } + + request = append(request, byte(port>>8), byte(port)) + _, err = conn.Write(request) + return err +} + +func readSOCKS5ConnectResponse(conn net.Conn) error { + header := make([]byte, 4) + if _, err := io.ReadFull(conn, header); err != nil { + return err + } + if header[0] != 0x05 { + return fmt.Errorf("invalid socks5 response version %d", header[0]) + } + if header[1] != 0x00 { + return fmt.Errorf("socks5 connect failed with code %d", header[1]) + } + + switch header[3] { + case 0x01: + if _, err := io.ReadFull(conn, make([]byte, 4)); err != nil { + return err + } + case 0x03: + size := make([]byte, 1) + if _, err := io.ReadFull(conn, size); err != nil { + return err + } + if _, err := io.ReadFull(conn, make([]byte, int(size[0]))); err != nil { + return err + } + case 0x04: + if _, err := io.ReadFull(conn, make([]byte, 16)); err != nil { + return err + } + default: + return fmt.Errorf("unsupported socks5 bind address type %d", header[3]) + } + + _, err := io.ReadFull(conn, make([]byte, 2)) + return err +} + +func applyConnDeadline(conn net.Conn, ctx context.Context, timeout time.Duration) func() { + if conn == nil { + return func() {} + } + + var ( + deadline time.Time + hasValue bool + ) + if ctx != nil { + if ctxDeadline, ok := ctx.Deadline(); ok { + deadline = ctxDeadline + hasValue = true + } + } + if timeout > 0 { + timeoutDeadline := time.Now().Add(timeout) + if !hasValue || timeoutDeadline.Before(deadline) { + deadline = timeoutDeadline + hasValue = true + } + } + if !hasValue { + return func() {} + } + + _ = conn.SetDeadline(deadline) + return func() { + _ = conn.SetDeadline(time.Time{}) + } +} + +func normalizeProxyConfig(proxy *ProxyConfig, defaultTimeout time.Duration) *ProxyConfig { + if proxy == nil { + return nil + } + + normalized := *proxy + normalized.Type = ProxyType(strings.ToLower(strings.TrimSpace(string(normalized.Type)))) + normalized.Addr = strings.TrimSpace(normalized.Addr) + if normalized.Timeout <= 0 { + normalized.Timeout = defaultTimeout + } + return &normalized +} + +func joinHostPort(host string, port int) string { + return net.JoinHostPort(strings.TrimSpace(host), strconv.Itoa(port)) +} diff --git a/types.go b/types.go new file mode 100644 index 0000000..7e2fe35 --- /dev/null +++ b/types.go @@ -0,0 +1,196 @@ +package starssh + +import ( + "bufio" + "context" + "io" + "net" + "sync" + "time" + + "github.com/pkg/sftp" + "golang.org/x/crypto/ssh" +) + +const ( + defaultSSHPort = 22 + defaultLoginTimeout = 5 * time.Second + defaultKeepAliveTimeout = 3 * time.Second + defaultShellPollInterval = 120 * time.Millisecond + defaultShellSetupDelay = 200 * time.Millisecond + defaultShellSetupTimeout = 3 * time.Second + defaultShellWaitTimeout = 30 * time.Second + defaultShellPromptToken = "__STARSSH_PROMPT__>" + defaultPTYTerm = "xterm" + defaultPTYRows = 500 + defaultPTYColumns = 250 + + defaultTransferBufferSize = 1024 * 1024 + + defaultExecStreamMaxPendingChunks = 256 + defaultExecStreamMaxPendingBytes = 4 * 1024 * 1024 +) + +type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error) + +type ProxyType string + +const ( + ProxyTypeSOCKS5 ProxyType = "socks5" + ProxyTypeHTTPConnect ProxyType = "http_connect" +) + +type ProxyConfig struct { + Type ProxyType + Addr string + Username string + Password string + Timeout time.Duration +} + +type AuthMethodKind string + +const ( + AuthMethodPrivateKey AuthMethodKind = "private_key" + AuthMethodPassword AuthMethodKind = "password" + AuthMethodKeyboardInteractive AuthMethodKind = "keyboard_interactive" + AuthMethodSSHAgent AuthMethodKind = "ssh_agent" +) + +type StarSSH struct { + stateMu sync.RWMutex + Client *ssh.Client + PublicKey ssh.PublicKey + PubkeyBase64 string + Hostname string + RemoteAddr net.Addr + Banner string + LoginInfo LoginInput + online bool + upstream *StarSSH + sftpClient *sftp.Client + sftpMu sync.Mutex + keepaliveMu sync.Mutex + keepaliveStop chan struct{} + keepaliveDone chan struct{} +} + +type LoginInput struct { + KeyExchanges []string + Ciphers []string + MACs []string + User string + Password string + PasswordCallback func() (string, error) + KeyboardInteractiveCallback ssh.KeyboardInteractiveChallenge + Prikey string + PrikeyPwd string + DisableSSHAgent bool + AuthOrder []AuthMethodKind + Addr string + Port int + Timeout time.Duration + DialContext DialContextFunc + Proxy *ProxyConfig + Jump *LoginInput + KeepAliveInterval time.Duration + KeepAliveTimeout time.Duration + HostKeyCallback func(string, net.Addr, ssh.PublicKey) error + BannerCallback func(string) error +} + +// StarShell keeps the legacy prompt-driven helper for POSIX-style scripted shell interactions. +// It is not a generic cross-shell abstraction; for product-grade interactive terminals, prefer TerminalSession. +type StarShell struct { + Keyword string + UseWaitDefault bool + WaitTimeout time.Duration + Session *ssh.Session + in io.Writer + out *bufio.Reader + er *bufio.Reader + outbyte []byte + errbyte []byte + lastout int64 + errors error + isprint bool + isfuncs bool + iscolor bool + isecho bool + rw sync.RWMutex + funcs func(string) + writeMu sync.Mutex + commandMu sync.Mutex + closeOnce sync.Once + promptToken string +} + +type TerminalConfig struct { + Term string + Rows int + Columns int + Modes ssh.TerminalModes +} + +type TerminalControl byte + +const ( + TerminalControlInterrupt TerminalControl = 0x03 + TerminalControlEOF TerminalControl = 0x04 + TerminalControlBell TerminalControl = 0x07 + TerminalControlBackspace TerminalControl = 0x08 + TerminalControlLineKill TerminalControl = 0x15 + TerminalControlQuit TerminalControl = 0x1c + TerminalControlSuspend TerminalControl = 0x1a + TerminalControlPauseOutput TerminalControl = 0x13 + TerminalControlResumeOutput TerminalControl = 0x11 +) + +type TerminalCloseReason string + +const ( + TerminalCloseReasonUnknown TerminalCloseReason = "" + TerminalCloseReasonExit TerminalCloseReason = "exit" + TerminalCloseReasonSignal TerminalCloseReason = "signal" + TerminalCloseReasonClosed TerminalCloseReason = "closed" + TerminalCloseReasonContextCanceled TerminalCloseReason = "context_canceled" + TerminalCloseReasonDeadlineExceeded TerminalCloseReason = "deadline_exceeded" + TerminalCloseReasonTransportError TerminalCloseReason = "transport_error" +) + +type TerminalExitInfo struct { + ExitCode int + ExitSignal string + ExitMessage string + Reason TerminalCloseReason +} + +type TerminalSession struct { + Session *ssh.Session + ID string + Label string + Metadata map[string]string + stdin io.WriteCloser + stdout io.Reader + stderr io.Reader + + attachMu sync.RWMutex + in io.Reader + out io.Writer + errOut io.Writer + + runOnce sync.Once + runDone chan struct{} + runErr error + + waitOnce sync.Once + waitDone chan struct{} + + stateMu sync.RWMutex + waitErr error + exitInfo TerminalExitInfo + closeReason TerminalCloseReason + closeErr error + + closeOnce sync.Once +} diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..ef47880 --- /dev/null +++ b/utils.go @@ -0,0 +1,162 @@ +package starssh + +import ( + "crypto/rand" + "encoding/hex" + "fmt" + "regexp" + "strconv" + "strings" + "time" +) + +var ( + ansiCSIRegexp = regexp.MustCompile(`\x1b\[[0-9;?]*[ -/]*[@-~]`) + ansiOSCRegexp = regexp.MustCompile(`\x1b\][^\x07]*(\x07|\x1b\\)`) + leadingIntRegexp = regexp.MustCompile(`^[+-]?\d+`) +) + +func SedColor(str string) string { + return stripControlSequences(str) +} + +func normalizeShellOutput(raw string) string { + return strings.TrimSpace(strings.ReplaceAll(raw, "\r\n", "\n")) +} + +func stripControlSequences(raw string) string { + cleaned := ansiOSCRegexp.ReplaceAllString(raw, "") + cleaned = ansiCSIRegexp.ReplaceAllString(cleaned, "") + cleaned = strings.ReplaceAll(cleaned, "\r", "") + cleaned = strings.Map(func(r rune) rune { + if r == '\n' || r == '\t' { + return r + } + if r < 0x20 || r == 0x7f { + return -1 + } + return r + }, cleaned) + return cleaned +} + +func stripLeadingEcho(output string, command string, markerCommand string) string { + result := output + if command == "" { + return result + } + + withMarker := command + if markerCommand != "" { + withMarker += "\n" + markerCommand + } + if strings.HasPrefix(result, withMarker) { + return strings.TrimSpace(strings.TrimPrefix(result, withMarker)) + } + if strings.HasPrefix(result, command) { + return strings.TrimSpace(strings.TrimPrefix(result, command)) + } + return result +} + +func collectLinesWithoutTokens(output string, tokens ...string) string { + lines := strings.Split(output, "\n") + filtered := make([]string, 0, len(lines)) + + for _, line := range lines { + skip := false + for _, token := range tokens { + if token != "" && strings.Contains(line, token) { + skip = true + break + } + } + if !skip { + filtered = append(filtered, line) + } + } + + return strings.TrimSpace(strings.Join(filtered, "\n")) +} + +func collectLinesForCommandOutput(output string, promptToken string, tokens ...string) string { + lines := strings.Split(output, "\n") + filtered := make([]string, 0, len(lines)) + + for _, line := range lines { + skip := false + for _, token := range tokens { + if token != "" && strings.Contains(line, token) { + skip = true + break + } + } + if !skip && promptToken != "" && strings.Contains(line, promptToken) { + skip = true + } + if !skip { + filtered = append(filtered, line) + } + } + + return strings.TrimSpace(strings.Join(filtered, "\n")) +} + +func shellSingleQuote(s string) string { + return "'" + strings.ReplaceAll(s, "'", "'\"'\"'") + "'" +} + +func normalizeBufferSize(bufcap int) int { + if bufcap <= 0 { + return defaultTransferBufferSize + } + return bufcap +} + +func newCommandTokens() (beginToken string, endToken string) { + nonce := newNonce(8) + return "__STARSSH_BEGIN_" + nonce + "__", "__STARSSH_END_" + nonce + "__" +} + +func newNonce(size int) string { + if size <= 0 { + size = 8 + } + + buf := make([]byte, size) + if _, err := rand.Read(buf); err != nil { + return fmt.Sprintf("%d", time.Now().UnixNano()) + } + return strings.ToUpper(hex.EncodeToString(buf)) +} + +func splitByEndToken(output string, endToken string) (before string, exitCode int, found bool, parseErr error) { + prefix := endToken + ":" + lines := strings.Split(output, "\n") + beforeLines := make([]string, 0, len(lines)) + + for _, line := range lines { + trimmedLine := strings.TrimSpace(line) + if !strings.HasPrefix(trimmedLine, prefix) { + beforeLines = append(beforeLines, line) + continue + } + + codeText := strings.TrimSpace(trimmedLine[len(prefix):]) + match := leadingIntRegexp.FindString(codeText) + if match == "" || match != codeText { + beforeLines = append(beforeLines, line) + continue + } + + code, err := strconv.Atoi(match) + if err != nil { + beforeLines = append(beforeLines, line) + continue + } + + return strings.Join(beforeLines, "\n"), code, true, nil + } + + return strings.Join(beforeLines, "\n"), 0, false, nil +}