Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
672a111ec1
|
|||
|
0c23e7d4bf
|
|||
|
ad7c8b0587
|
|||
|
1625997d8f
|
|||
|
b29246a9c4
|
|||
|
f20eb653ae
|
|||
|
d6fbea8468
|
@@ -0,0 +1,9 @@
|
||||
.sentrux/
|
||||
agent_readme.md
|
||||
target.md
|
||||
.gocache/
|
||||
.tmp_*/
|
||||
.codex/
|
||||
.idea/
|
||||
agents.md
|
||||
.codex
|
||||
@@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
@@ -0,0 +1,454 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
sshagent "golang.org/x/crypto/ssh/agent"
|
||||
)
|
||||
|
||||
var requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||
return sshagent.RequestAgentForwarding(session)
|
||||
}
|
||||
|
||||
const sshAgentChannelType = "auth-agent@openssh.com"
|
||||
|
||||
var routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
|
||||
return startSSHAgentForwardProxy(client, timeouts)
|
||||
}
|
||||
|
||||
var probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||
conn, _, err := dialSSHAgentWithDebug("forward-probe", timeouts)
|
||||
if err != nil {
|
||||
return wrapSSHAgentForwardingUnavailable(err)
|
||||
}
|
||||
if conn == nil {
|
||||
return wrapSSHAgentForwardingUnavailable(errors.New("empty agent connection"))
|
||||
}
|
||||
return conn.Close()
|
||||
}
|
||||
|
||||
var errSSHAgentForwardingDenied = errors.New("ssh-agent forwarding request denied")
|
||||
var errSSHAgentForwardingUnavailable = errors.New("ssh-agent forwarding unavailable")
|
||||
|
||||
type sshAgentForwardProxy struct {
|
||||
stopOnce sync.Once
|
||||
stopCh chan struct{}
|
||||
|
||||
activeMu sync.Mutex
|
||||
active map[*sshAgentForwardBridge]struct{}
|
||||
}
|
||||
|
||||
func (p *sshAgentForwardProxy) Close() error {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
p.stopOnce.Do(func() {
|
||||
close(p.stopCh)
|
||||
})
|
||||
p.closeActive()
|
||||
return nil
|
||||
}
|
||||
|
||||
type sshAgentForwardBridge struct {
|
||||
proxy *sshAgentForwardProxy
|
||||
channel ssh.Channel
|
||||
conn net.Conn
|
||||
idleTimeout time.Duration
|
||||
|
||||
closeOnce sync.Once
|
||||
signalOnce sync.Once
|
||||
done chan struct{}
|
||||
activity chan struct{}
|
||||
}
|
||||
|
||||
func (s *StarSSH) RequestAgentForwarding(session *ssh.Session) error {
|
||||
if s == nil {
|
||||
return errors.New("ssh client is nil")
|
||||
}
|
||||
if session == nil {
|
||||
return errors.New("ssh session is nil")
|
||||
}
|
||||
if err := s.ensureAgentForwarding(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := requestSSHAgentForwarding(session); err != nil {
|
||||
if isSSHAgentForwardingDeniedError(err) {
|
||||
return fmt.Errorf("%w: %v", errSSHAgentForwardingDenied, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *StarSSH) maybeRequestAgentForwarding(session *ssh.Session) error {
|
||||
if s == nil || !s.LoginInfo.ForwardSSHAgent {
|
||||
return nil
|
||||
}
|
||||
err := s.RequestAgentForwarding(session)
|
||||
if isSSHAgentForwardingDeniedError(err) || isSSHAgentForwardingUnavailableError(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *StarSSH) ensureAgentForwarding() error {
|
||||
if s == nil {
|
||||
return errors.New("ssh client is nil")
|
||||
}
|
||||
|
||||
s.agentForwardMu.Lock()
|
||||
defer s.agentForwardMu.Unlock()
|
||||
|
||||
if s.agentForwarder != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
client, err := s.requireSSHClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
timeouts := effectiveSSHAgentTimeouts(s.LoginInfo)
|
||||
if err := probeSSHAgentForwarding(timeouts); err != nil {
|
||||
return wrapSSHAgentForwardingUnavailable(err)
|
||||
}
|
||||
if s.closing.Load() {
|
||||
return errSSHClientClosing
|
||||
}
|
||||
closer, err := routeSSHAgentForwarding(client, timeouts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !s.canAttachAgentForwarder(client) {
|
||||
if closer != nil {
|
||||
_ = closer.Close()
|
||||
}
|
||||
return errSSHClientClosing
|
||||
}
|
||||
s.agentForwarder = closer
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *StarSSH) takeAgentForwarder() io.Closer {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.agentForwardMu.Lock()
|
||||
defer s.agentForwardMu.Unlock()
|
||||
|
||||
closer := s.agentForwarder
|
||||
s.agentForwarder = nil
|
||||
return closer
|
||||
}
|
||||
|
||||
func isSSHAgentForwardingDeniedError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, errSSHAgentForwardingDenied) {
|
||||
return true
|
||||
}
|
||||
message := strings.ToLower(err.Error())
|
||||
return strings.Contains(message, "forwarding request denied") ||
|
||||
strings.Contains(message, "agent forwarding disabled")
|
||||
}
|
||||
|
||||
func isSSHAgentForwardingUnavailableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, errSSHAgentForwardingUnavailable) {
|
||||
return true
|
||||
}
|
||||
message := strings.ToLower(err.Error())
|
||||
return strings.Contains(message, "ssh-agent forwarding unavailable") ||
|
||||
strings.Contains(message, "ssh-agent unavailable")
|
||||
}
|
||||
|
||||
func wrapSSHAgentForwardingUnavailable(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, errSSHAgentForwardingUnavailable) {
|
||||
return err
|
||||
}
|
||||
if errors.Is(err, errSSHAgentUnavailable) {
|
||||
return fmt.Errorf("%w: %w", errSSHAgentForwardingUnavailable, err)
|
||||
}
|
||||
return fmt.Errorf("%w: %v", errSSHAgentForwardingUnavailable, err)
|
||||
}
|
||||
|
||||
func startSSHAgentForwardProxy(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
|
||||
if client == nil {
|
||||
return nil, errors.New("ssh client is nil")
|
||||
}
|
||||
channels := client.HandleChannelOpen(sshAgentChannelType)
|
||||
if channels == nil {
|
||||
return nil, errors.New("agent: already have handler for " + sshAgentChannelType)
|
||||
}
|
||||
|
||||
proxy := &sshAgentForwardProxy{
|
||||
stopCh: make(chan struct{}),
|
||||
active: make(map[*sshAgentForwardBridge]struct{}),
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-proxy.stopCh:
|
||||
return
|
||||
case ch, ok := <-channels:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
go handleSSHAgentForwardChannel(proxy, ch, timeouts)
|
||||
}
|
||||
}
|
||||
}()
|
||||
return proxy, nil
|
||||
}
|
||||
|
||||
func handleSSHAgentForwardChannel(proxy *sshAgentForwardProxy, ch ssh.NewChannel, timeouts sshAgentTimeouts) {
|
||||
if ch == nil {
|
||||
return
|
||||
}
|
||||
conn, _, err := dialSSHAgentWithDebug("forward-channel", timeouts)
|
||||
if err != nil {
|
||||
_ = ch.Reject(ssh.ConnectionFailed, err.Error())
|
||||
return
|
||||
}
|
||||
if conn == nil {
|
||||
_ = ch.Reject(ssh.ConnectionFailed, "ssh-agent connection unavailable")
|
||||
return
|
||||
}
|
||||
channel, reqs, err := ch.Accept()
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
bridge := &sshAgentForwardBridge{
|
||||
proxy: proxy,
|
||||
channel: channel,
|
||||
conn: conn,
|
||||
idleTimeout: timeouts.Forward,
|
||||
}
|
||||
if !proxy.registerBridge(bridge) {
|
||||
bridge.close()
|
||||
return
|
||||
}
|
||||
go bridge.run()
|
||||
}
|
||||
|
||||
func proxySSHAgentChannel(channel ssh.Channel, conn net.Conn) {
|
||||
bridge := &sshAgentForwardBridge{
|
||||
channel: channel,
|
||||
conn: conn,
|
||||
}
|
||||
bridge.run()
|
||||
}
|
||||
|
||||
func (b *sshAgentForwardBridge) run() {
|
||||
if b == nil {
|
||||
return
|
||||
}
|
||||
b.ensureSignals()
|
||||
stopWatchdog := b.startIdleWatchdog()
|
||||
defer stopWatchdog()
|
||||
defer b.unregister()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = io.Copy(
|
||||
sshAgentForwardActivityWriter{Writer: b.channel, touch: b.touch},
|
||||
sshAgentForwardActivityReader{Reader: b.conn, touch: b.touch},
|
||||
)
|
||||
b.close()
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = io.Copy(
|
||||
sshAgentForwardActivityWriter{Writer: b.conn, touch: b.touch},
|
||||
sshAgentForwardActivityReader{Reader: b.channel, touch: b.touch},
|
||||
)
|
||||
b.close()
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (b *sshAgentForwardBridge) close() {
|
||||
if b == nil {
|
||||
return
|
||||
}
|
||||
b.closeOnce.Do(func() {
|
||||
b.ensureSignals()
|
||||
close(b.done)
|
||||
closeWriter(b.channel)
|
||||
closeWriter(b.conn)
|
||||
if b.channel != nil {
|
||||
_ = b.channel.Close()
|
||||
}
|
||||
if b.conn != nil {
|
||||
_ = b.conn.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (b *sshAgentForwardBridge) ensureSignals() {
|
||||
if b == nil {
|
||||
return
|
||||
}
|
||||
b.signalOnce.Do(func() {
|
||||
b.done = make(chan struct{})
|
||||
b.activity = make(chan struct{}, 1)
|
||||
})
|
||||
}
|
||||
|
||||
func (b *sshAgentForwardBridge) startIdleWatchdog() func() {
|
||||
if b == nil || b.idleTimeout <= 0 {
|
||||
return func() {}
|
||||
}
|
||||
b.ensureSignals()
|
||||
timer := time.NewTimer(b.idleTimeout)
|
||||
stopped := make(chan struct{})
|
||||
go func() {
|
||||
defer timer.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-timer.C:
|
||||
b.close()
|
||||
return
|
||||
case <-b.activity:
|
||||
resetTimer(timer, b.idleTimeout)
|
||||
case <-b.done:
|
||||
return
|
||||
case <-stopped:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
return func() {
|
||||
close(stopped)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *sshAgentForwardBridge) touch() {
|
||||
if b == nil || b.idleTimeout <= 0 || b.activity == nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case b.activity <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
type sshAgentForwardActivityReader struct {
|
||||
io.Reader
|
||||
touch func()
|
||||
}
|
||||
|
||||
func (r sshAgentForwardActivityReader) Read(p []byte) (int, error) {
|
||||
n, err := r.Reader.Read(p)
|
||||
if n > 0 && r.touch != nil {
|
||||
r.touch()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
type sshAgentForwardActivityWriter struct {
|
||||
io.Writer
|
||||
touch func()
|
||||
}
|
||||
|
||||
func (w sshAgentForwardActivityWriter) Write(p []byte) (int, error) {
|
||||
n, err := w.Writer.Write(p)
|
||||
if n > 0 && w.touch != nil {
|
||||
w.touch()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func resetTimer(timer *time.Timer, timeout time.Duration) {
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
timer.Reset(timeout)
|
||||
}
|
||||
|
||||
func (b *sshAgentForwardBridge) unregister() {
|
||||
if b == nil || b.proxy == nil {
|
||||
return
|
||||
}
|
||||
b.proxy.unregisterBridge(b)
|
||||
}
|
||||
|
||||
func (p *sshAgentForwardProxy) registerBridge(bridge *sshAgentForwardBridge) bool {
|
||||
if p == nil || bridge == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
p.activeMu.Lock()
|
||||
defer p.activeMu.Unlock()
|
||||
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
return false
|
||||
default:
|
||||
}
|
||||
if p.active == nil {
|
||||
p.active = make(map[*sshAgentForwardBridge]struct{})
|
||||
}
|
||||
p.active[bridge] = struct{}{}
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *sshAgentForwardProxy) unregisterBridge(bridge *sshAgentForwardBridge) {
|
||||
if p == nil || bridge == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.activeMu.Lock()
|
||||
defer p.activeMu.Unlock()
|
||||
|
||||
delete(p.active, bridge)
|
||||
}
|
||||
|
||||
func (p *sshAgentForwardProxy) closeActive() {
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.activeMu.Lock()
|
||||
active := make([]*sshAgentForwardBridge, 0, len(p.active))
|
||||
for bridge := range p.active {
|
||||
active = append(active, bridge)
|
||||
}
|
||||
p.active = make(map[*sshAgentForwardBridge]struct{})
|
||||
p.activeMu.Unlock()
|
||||
|
||||
for _, bridge := range active {
|
||||
bridge.close()
|
||||
}
|
||||
}
|
||||
|
||||
func closeWriter(value any) {
|
||||
type closeWriter interface {
|
||||
CloseWrite() error
|
||||
}
|
||||
if cw, ok := value.(closeWriter); ok {
|
||||
_ = cw.CloseWrite()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,665 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type testCloser struct {
|
||||
closed atomic.Int32
|
||||
}
|
||||
|
||||
func (c *testCloser) Close() error {
|
||||
c.closed.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
type trackedConn struct {
|
||||
net.Conn
|
||||
closed atomic.Int32
|
||||
}
|
||||
|
||||
func (c *trackedConn) Close() error {
|
||||
c.closed.Add(1)
|
||||
if c.Conn == nil {
|
||||
return nil
|
||||
}
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
type testSSHChannel struct {
|
||||
readFunc func([]byte) (int, error)
|
||||
|
||||
stderr bytes.Buffer
|
||||
closed atomic.Int32
|
||||
closeOnce sync.Once
|
||||
closeCh chan struct{}
|
||||
}
|
||||
|
||||
type testNewChannel struct {
|
||||
channel ssh.Channel
|
||||
accepted atomic.Bool
|
||||
rejected atomic.Bool
|
||||
}
|
||||
|
||||
func (c *testNewChannel) Accept() (ssh.Channel, <-chan *ssh.Request, error) {
|
||||
c.accepted.Store(true)
|
||||
requests := make(chan *ssh.Request)
|
||||
close(requests)
|
||||
return c.channel, requests, nil
|
||||
}
|
||||
|
||||
func (c *testNewChannel) Reject(reason ssh.RejectionReason, message string) error {
|
||||
c.rejected.Store(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *testNewChannel) ChannelType() string {
|
||||
return sshAgentChannelType
|
||||
}
|
||||
|
||||
func (c *testNewChannel) ExtraData() []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTestSSHChannel(readFunc func([]byte) (int, error)) *testSSHChannel {
|
||||
return &testSSHChannel{
|
||||
readFunc: readFunc,
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func newBlockingTestSSHChannel() *testSSHChannel {
|
||||
ch := newTestSSHChannel(nil)
|
||||
ch.readFunc = func(p []byte) (int, error) {
|
||||
<-ch.closeCh
|
||||
return 0, io.EOF
|
||||
}
|
||||
return ch
|
||||
}
|
||||
|
||||
func (c *testSSHChannel) Read(p []byte) (int, error) {
|
||||
if c == nil {
|
||||
return 0, io.EOF
|
||||
}
|
||||
if c.readFunc != nil {
|
||||
return c.readFunc(p)
|
||||
}
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (c *testSSHChannel) Write(p []byte) (int, error) {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (c *testSSHChannel) Close() error {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
c.closeOnce.Do(func() {
|
||||
c.closed.Add(1)
|
||||
close(c.closeCh)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *testSSHChannel) CloseWrite() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *testSSHChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (c *testSSHChannel) Stderr() io.ReadWriter {
|
||||
return &c.stderr
|
||||
}
|
||||
|
||||
func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
|
||||
oldNewSSHSession := newSSHSession
|
||||
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||
oldCloseSSHClient := closeSSHClient
|
||||
t.Cleanup(func() {
|
||||
newSSHSession = oldNewSSHSession
|
||||
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||
closeSSHClient = oldCloseSSHClient
|
||||
})
|
||||
|
||||
baseClient := &ssh.Client{}
|
||||
star := &StarSSH{
|
||||
LoginInfo: LoginInput{
|
||||
ForwardSSHAgent: true,
|
||||
Timeout: time.Second,
|
||||
SSHAgentTimeout: 3 * time.Second,
|
||||
SSHAgentForwardTimeout: 4 * time.Second,
|
||||
},
|
||||
}
|
||||
star.setTransport(baseClient, nil)
|
||||
|
||||
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||
if client != baseClient {
|
||||
t.Fatalf("unexpected ssh client %p", client)
|
||||
}
|
||||
return &ssh.Session{}, nil
|
||||
}
|
||||
|
||||
var probeCalls atomic.Int32
|
||||
closer := &testCloser{}
|
||||
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||
probeCalls.Add(1)
|
||||
if timeouts.Dial != time.Second {
|
||||
t.Fatalf("unexpected forwarding dial timeout: %v", timeouts.Dial)
|
||||
}
|
||||
if timeouts.Operation != 3*time.Second {
|
||||
t.Fatalf("unexpected forwarding operation timeout: %v", timeouts.Operation)
|
||||
}
|
||||
if timeouts.Forward != 4*time.Second {
|
||||
t.Fatalf("unexpected forwarding idle timeout: %v", timeouts.Forward)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var routeCalls atomic.Int32
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
|
||||
routeCalls.Add(1)
|
||||
if client != baseClient {
|
||||
t.Fatalf("unexpected routed client %p", client)
|
||||
}
|
||||
if timeouts.Dial != time.Second {
|
||||
t.Fatalf("unexpected routed dial timeout: %v", timeouts.Dial)
|
||||
}
|
||||
if timeouts.Operation != 3*time.Second {
|
||||
t.Fatalf("unexpected routed operation timeout: %v", timeouts.Operation)
|
||||
}
|
||||
if timeouts.Forward != 4*time.Second {
|
||||
t.Fatalf("unexpected routed idle timeout: %v", timeouts.Forward)
|
||||
}
|
||||
return closer, nil
|
||||
}
|
||||
|
||||
var requestCalls atomic.Int32
|
||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||
requestCalls.Add(1)
|
||||
if session == nil {
|
||||
t.Fatal("expected non-nil ssh session")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := star.NewExecSession(); err != nil {
|
||||
t.Fatalf("first exec session: %v", err)
|
||||
}
|
||||
if _, err := star.NewExecSession(); err != nil {
|
||||
t.Fatalf("second exec session: %v", err)
|
||||
}
|
||||
|
||||
if probeCalls.Load() != 1 {
|
||||
t.Fatalf("expected one agent probe, got %d", probeCalls.Load())
|
||||
}
|
||||
if routeCalls.Load() != 1 {
|
||||
t.Fatalf("expected one agent route registration, got %d", routeCalls.Load())
|
||||
}
|
||||
if requestCalls.Load() != 2 {
|
||||
t.Fatalf("expected agent forwarding request on each session, got %d", requestCalls.Load())
|
||||
}
|
||||
|
||||
closeSSHClient = func(client sshClientRequester) error { return nil }
|
||||
if err := star.Close(); err != nil {
|
||||
t.Fatalf("close starssh: %v", err)
|
||||
}
|
||||
if closer.closed.Load() != 1 {
|
||||
t.Fatalf("expected forwarded agent closer to run once, got %d", closer.closed.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) {
|
||||
oldNewSSHSession := newSSHSession
|
||||
oldRequestSessionPTY := requestSessionPTY
|
||||
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||
t.Cleanup(func() {
|
||||
newSSHSession = oldNewSSHSession
|
||||
requestSessionPTY = oldRequestSessionPTY
|
||||
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||
})
|
||||
|
||||
star := &StarSSH{
|
||||
LoginInfo: LoginInput{
|
||||
ForwardSSHAgent: true,
|
||||
},
|
||||
}
|
||||
star.setTransport(&ssh.Client{}, nil)
|
||||
|
||||
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||
return &ssh.Session{}, nil
|
||||
}
|
||||
|
||||
var ptyCalls atomic.Int32
|
||||
requestSessionPTY = func(session *ssh.Session, config TerminalConfig) error {
|
||||
ptyCalls.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||
return nil
|
||||
}
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
|
||||
return &testCloser{}, nil
|
||||
}
|
||||
|
||||
var requestCalls atomic.Int32
|
||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||
requestCalls.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := star.NewPTYSession(nil); err != nil {
|
||||
t.Fatalf("new pty session: %v", err)
|
||||
}
|
||||
if ptyCalls.Load() != 1 {
|
||||
t.Fatalf("expected one PTY request, got %d", ptyCalls.Load())
|
||||
}
|
||||
if requestCalls.Load() != 1 {
|
||||
t.Fatalf("expected one agent forwarding request, got %d", requestCalls.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) {
|
||||
oldNewSSHSession := newSSHSession
|
||||
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||
t.Cleanup(func() {
|
||||
newSSHSession = oldNewSSHSession
|
||||
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||
})
|
||||
|
||||
star := &StarSSH{}
|
||||
star.setTransport(&ssh.Client{}, nil)
|
||||
|
||||
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||
return &ssh.Session{}, nil
|
||||
}
|
||||
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||
t.Fatal("agent forwarding probe should not run when disabled")
|
||||
return nil
|
||||
}
|
||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||
t.Fatal("agent forwarding should not be requested when disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := star.NewExecSession(); err != nil {
|
||||
t.Fatalf("new exec session without forwarding: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestAgentForwardingReturnsUnavailableError(t *testing.T) {
|
||||
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||
t.Cleanup(func() {
|
||||
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||
})
|
||||
|
||||
star := &StarSSH{}
|
||||
star.setTransport(&ssh.Client{}, nil)
|
||||
|
||||
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||
return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
|
||||
}
|
||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||
t.Fatal("session request should not run when agent forwarder init fails")
|
||||
return nil
|
||||
}
|
||||
|
||||
err := star.RequestAgentForwarding(&ssh.Session{})
|
||||
if err == nil {
|
||||
t.Fatal("expected agent forwarding init error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestAgentForwardingWrapsSetupErrorAsUnavailable(t *testing.T) {
|
||||
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||
t.Cleanup(func() {
|
||||
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||
})
|
||||
|
||||
star := &StarSSH{}
|
||||
star.setTransport(&ssh.Client{}, nil)
|
||||
|
||||
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||
return errors.New("dial unix /tmp/ssh-broken.sock: connect: permission denied")
|
||||
}
|
||||
|
||||
err := star.RequestAgentForwarding(&ssh.Session{})
|
||||
if !isSSHAgentForwardingUnavailableError(err) {
|
||||
t.Fatalf("expected unavailable error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) {
|
||||
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||
t.Cleanup(func() {
|
||||
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||
})
|
||||
|
||||
star := &StarSSH{}
|
||||
star.setTransport(&ssh.Client{}, nil)
|
||||
|
||||
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||
return nil
|
||||
}
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
|
||||
return &testCloser{}, nil
|
||||
}
|
||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||
return errors.New("forwarding request denied")
|
||||
}
|
||||
|
||||
err := star.RequestAgentForwarding(&ssh.Session{})
|
||||
if !isSSHAgentForwardingDeniedError(err) {
|
||||
t.Fatalf("expected forwarding denied error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) {
|
||||
oldNewSSHSession := newSSHSession
|
||||
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||
t.Cleanup(func() {
|
||||
newSSHSession = oldNewSSHSession
|
||||
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||
})
|
||||
|
||||
star := &StarSSH{
|
||||
LoginInfo: LoginInput{
|
||||
ForwardSSHAgent: true,
|
||||
},
|
||||
}
|
||||
star.setTransport(&ssh.Client{}, nil)
|
||||
|
||||
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||
return &ssh.Session{}, nil
|
||||
}
|
||||
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||
return nil
|
||||
}
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
|
||||
return &testCloser{}, nil
|
||||
}
|
||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||
return errors.New("forwarding request denied")
|
||||
}
|
||||
|
||||
if _, err := star.NewExecSession(); err != nil {
|
||||
t.Fatalf("new exec session should ignore denied agent forwarding: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) {
|
||||
oldNewSSHSession := newSSHSession
|
||||
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||
t.Cleanup(func() {
|
||||
newSSHSession = oldNewSSHSession
|
||||
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||
})
|
||||
|
||||
star := &StarSSH{
|
||||
LoginInfo: LoginInput{
|
||||
ForwardSSHAgent: true,
|
||||
},
|
||||
}
|
||||
star.setTransport(&ssh.Client{}, nil)
|
||||
|
||||
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||
return &ssh.Session{}, nil
|
||||
}
|
||||
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||
return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
|
||||
}
|
||||
|
||||
if _, err := star.NewExecSession(); err != nil {
|
||||
t.Fatalf("new exec session should ignore unavailable agent forwarding: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) {
|
||||
oldNewSSHSession := newSSHSession
|
||||
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||
t.Cleanup(func() {
|
||||
newSSHSession = oldNewSSHSession
|
||||
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||
})
|
||||
|
||||
star := &StarSSH{
|
||||
LoginInfo: LoginInput{
|
||||
ForwardSSHAgent: true,
|
||||
},
|
||||
}
|
||||
star.setTransport(&ssh.Client{}, nil)
|
||||
|
||||
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||
return &ssh.Session{}, nil
|
||||
}
|
||||
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||
return errors.New("dial unix /tmp/ssh-broken.sock: connect: connection refused")
|
||||
}
|
||||
|
||||
if _, err := star.NewExecSession(); err != nil {
|
||||
t.Fatalf("new exec session should ignore agent setup error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) {
|
||||
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||
oldCloseSSHClient := closeSSHClient
|
||||
t.Cleanup(func() {
|
||||
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||
closeSSHClient = oldCloseSSHClient
|
||||
})
|
||||
|
||||
star := &StarSSH{
|
||||
LoginInfo: LoginInput{
|
||||
ForwardSSHAgent: true,
|
||||
},
|
||||
}
|
||||
star.setTransport(&ssh.Client{}, nil)
|
||||
|
||||
started := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
closer := &testCloser{}
|
||||
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||
return nil
|
||||
}
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
|
||||
close(started)
|
||||
<-release
|
||||
return closer, nil
|
||||
}
|
||||
closeSSHClient = func(client sshClientRequester) error { return nil }
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- star.ensureAgentForwarding()
|
||||
}()
|
||||
|
||||
<-started
|
||||
closeDone := make(chan struct{})
|
||||
go func() {
|
||||
_ = star.Close()
|
||||
close(closeDone)
|
||||
}()
|
||||
|
||||
deadline := time.Now().Add(time.Second)
|
||||
for !star.closing.Load() {
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatal("close did not enter closing state in time")
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
close(release)
|
||||
|
||||
err := <-errCh
|
||||
if !errors.Is(err, errSSHClientClosing) {
|
||||
t.Fatalf("expected closing error, got %v", err)
|
||||
}
|
||||
<-closeDone
|
||||
|
||||
if closer.closed.Load() != 1 {
|
||||
t.Fatalf("expected new forwarder closer to be closed once, got %d", closer.closed.Load())
|
||||
}
|
||||
if got := star.takeAgentForwarder(); got != nil {
|
||||
t.Fatal("expected no leaked agent forwarder after close race")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxySSHAgentChannelClosesBlockedAgentConnWhenRemoteChannelEnds(t *testing.T) {
|
||||
agentConn, peerConn := net.Pipe()
|
||||
defer peerConn.Close()
|
||||
|
||||
tracked := &trackedConn{Conn: agentConn}
|
||||
channel := newTestSSHChannel(func(p []byte) (int, error) {
|
||||
return 0, io.EOF
|
||||
})
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
proxySSHAgentChannel(channel, tracked)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("proxySSHAgentChannel did not exit after remote EOF")
|
||||
}
|
||||
|
||||
if tracked.closed.Load() == 0 {
|
||||
t.Fatal("expected local agent connection to be closed")
|
||||
}
|
||||
if channel.closed.Load() == 0 {
|
||||
t.Fatal("expected ssh channel to be closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHAgentForwardProxyCloseClosesActiveBridges(t *testing.T) {
|
||||
agentConn, peerConn := net.Pipe()
|
||||
defer peerConn.Close()
|
||||
|
||||
tracked := &trackedConn{Conn: agentConn}
|
||||
channel := newBlockingTestSSHChannel()
|
||||
proxy := &sshAgentForwardProxy{
|
||||
stopCh: make(chan struct{}),
|
||||
active: make(map[*sshAgentForwardBridge]struct{}),
|
||||
}
|
||||
bridge := &sshAgentForwardBridge{
|
||||
proxy: proxy,
|
||||
channel: channel,
|
||||
conn: tracked,
|
||||
}
|
||||
if !proxy.registerBridge(bridge) {
|
||||
t.Fatal("expected bridge registration to succeed")
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
bridge.run()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
if err := proxy.Close(); err != nil {
|
||||
t.Fatalf("close proxy: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("bridge did not exit after proxy close")
|
||||
}
|
||||
|
||||
if tracked.closed.Load() == 0 {
|
||||
t.Fatal("expected proxy close to close local agent connection")
|
||||
}
|
||||
if channel.closed.Load() == 0 {
|
||||
t.Fatal("expected proxy close to close ssh channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSSHAgentForwardChannelUsesForwardTimeout(t *testing.T) {
|
||||
oldDialResolvedSSHAgent := dialResolvedSSHAgentFunc
|
||||
t.Cleanup(func() {
|
||||
dialResolvedSSHAgentFunc = oldDialResolvedSSHAgent
|
||||
})
|
||||
|
||||
agentConn, peerConn := net.Pipe()
|
||||
defer peerConn.Close()
|
||||
tracked := &trackedConn{Conn: agentConn}
|
||||
dialResolvedSSHAgentFunc = func(resolved resolvedSSHAgentEndpoint, timeout time.Duration) (net.Conn, error) {
|
||||
return tracked, nil
|
||||
}
|
||||
|
||||
channel := newBlockingTestSSHChannel()
|
||||
newChannel := &testNewChannel{
|
||||
channel: channel,
|
||||
}
|
||||
proxy := &sshAgentForwardProxy{
|
||||
stopCh: make(chan struct{}),
|
||||
active: make(map[*sshAgentForwardBridge]struct{}),
|
||||
}
|
||||
handleSSHAgentForwardChannel(proxy, newChannel, sshAgentTimeouts{
|
||||
Endpoint: "/tmp/agent.sock",
|
||||
Forward: 20 * time.Millisecond,
|
||||
})
|
||||
|
||||
if !newChannel.accepted.Load() {
|
||||
t.Fatal("expected channel to be accepted")
|
||||
}
|
||||
|
||||
waitUntil(t, time.Second, func() bool {
|
||||
return tracked.closed.Load() > 0 && channel.closed.Load() > 0
|
||||
}, "forwarded agent bridge did not close both sides after idle timeout")
|
||||
|
||||
waitUntil(t, time.Second, func() bool {
|
||||
proxy.activeMu.Lock()
|
||||
defer proxy.activeMu.Unlock()
|
||||
return len(proxy.active) == 0
|
||||
}, "forwarded agent bridge did not unregister after idle timeout")
|
||||
}
|
||||
|
||||
func waitUntil(t *testing.T, timeout time.Duration, condition func() bool, message string) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if condition() {
|
||||
return
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
t.Fatal(message)
|
||||
}
|
||||
@@ -0,0 +1,172 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestPingContextDoesNotCloseConnectionOnCancel(t *testing.T) {
|
||||
oldSendKeepAliveRequest := sendKeepAliveRequest
|
||||
oldCloseSSHClient := closeSSHClient
|
||||
t.Cleanup(func() {
|
||||
sendKeepAliveRequest = oldSendKeepAliveRequest
|
||||
closeSSHClient = oldCloseSSHClient
|
||||
})
|
||||
|
||||
sendKeepAliveRequest = func(ctx context.Context, client sshClientRequester) error {
|
||||
<-ctx.Done()
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var closeCalls atomic.Int32
|
||||
closeSSHClient = func(client sshClientRequester) error {
|
||||
closeCalls.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
client := &ssh.Client{}
|
||||
star := &StarSSH{}
|
||||
star.setTransport(client, nil)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := star.PingContext(ctx)
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Fatalf("expected deadline exceeded, got %v", err)
|
||||
}
|
||||
if closeCalls.Load() != 0 {
|
||||
t.Fatalf("expected PingContext to keep connection open, close calls=%d", closeCalls.Load())
|
||||
}
|
||||
if got := star.snapshotSSHClient(); got != client {
|
||||
t.Fatal("expected ssh client to remain attached after PingContext cancel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPingContextCloseOnCancelClosesConnection(t *testing.T) {
|
||||
oldSendKeepAliveRequest := sendKeepAliveRequest
|
||||
oldCloseSSHClient := closeSSHClient
|
||||
t.Cleanup(func() {
|
||||
sendKeepAliveRequest = oldSendKeepAliveRequest
|
||||
closeSSHClient = oldCloseSSHClient
|
||||
})
|
||||
|
||||
sendKeepAliveRequest = func(ctx context.Context, client sshClientRequester) error {
|
||||
<-ctx.Done()
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var closeCalls atomic.Int32
|
||||
closeSSHClient = func(client sshClientRequester) error {
|
||||
closeCalls.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
star := &StarSSH{}
|
||||
star.setTransport(&ssh.Client{}, nil)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := star.PingContextCloseOnCancel(ctx)
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Fatalf("expected deadline exceeded, got %v", err)
|
||||
}
|
||||
if closeCalls.Load() != 1 {
|
||||
t.Fatalf("expected exactly one close call, got %d", closeCalls.Load())
|
||||
}
|
||||
if got := star.snapshotSSHClient(); got != nil {
|
||||
t.Fatal("expected ssh client to be detached after PingContextCloseOnCancel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialTCPContextDoesNotCloseConnectionOnCancel(t *testing.T) {
|
||||
oldDialSSHClient := dialSSHClient
|
||||
oldCloseSSHClient := closeSSHClient
|
||||
t.Cleanup(func() {
|
||||
dialSSHClient = oldDialSSHClient
|
||||
closeSSHClient = oldCloseSSHClient
|
||||
})
|
||||
|
||||
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
|
||||
<-ctx.Done()
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
var closeCalls atomic.Int32
|
||||
closeSSHClient = func(client sshClientRequester) error {
|
||||
closeCalls.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
client := &ssh.Client{}
|
||||
star := &StarSSH{}
|
||||
star.setTransport(client, nil)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
conn, err := star.DialTCPContext(ctx, "tcp", "127.0.0.1:22")
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil connection on canceled dial")
|
||||
}
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Fatalf("expected deadline exceeded, got %v", err)
|
||||
}
|
||||
if closeCalls.Load() != 0 {
|
||||
t.Fatalf("expected DialTCPContext to keep connection open, close calls=%d", closeCalls.Load())
|
||||
}
|
||||
if got := star.snapshotSSHClient(); got != client {
|
||||
t.Fatal("expected ssh client to remain attached after DialTCPContext cancel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialTCPContextCloseOnCancelClosesConnection(t *testing.T) {
|
||||
oldDialSSHClient := dialSSHClient
|
||||
oldCloseSSHClient := closeSSHClient
|
||||
t.Cleanup(func() {
|
||||
dialSSHClient = oldDialSSHClient
|
||||
closeSSHClient = oldCloseSSHClient
|
||||
})
|
||||
|
||||
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
|
||||
<-ctx.Done()
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
var closeCalls atomic.Int32
|
||||
closeSSHClient = func(client sshClientRequester) error {
|
||||
closeCalls.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
star := &StarSSH{}
|
||||
star.setTransport(&ssh.Client{}, nil)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
conn, err := star.DialTCPContextCloseOnCancel(ctx, "tcp", "127.0.0.1:22")
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil connection on canceled dial")
|
||||
}
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Fatalf("expected deadline exceeded, got %v", err)
|
||||
}
|
||||
if closeCalls.Load() != 1 {
|
||||
t.Fatalf("expected exactly one close call, got %d", closeCalls.Load())
|
||||
}
|
||||
if got := star.snapshotSSHClient(); got != nil {
|
||||
t.Fatal("expected ssh client to be detached after DialTCPContextCloseOnCancel")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExistsFallsBackFromPOSIXToPowerShell(t *testing.T) {
|
||||
oldExecRequestRunner := execRequestRunner
|
||||
t.Cleanup(func() {
|
||||
execRequestRunner = oldExecRequestRunner
|
||||
})
|
||||
|
||||
var calls []ExecRequest
|
||||
execRequestRunner = func(s *StarSSH, ctx context.Context, req ExecRequest) (*ExecResult, error) {
|
||||
calls = append(calls, req)
|
||||
switch len(calls) {
|
||||
case 1:
|
||||
return &ExecResult{ExitCode: 127, Stderr: []byte("sh not found")}, nil
|
||||
case 2:
|
||||
return &ExecResult{Stdout: []byte("1\n")}, nil
|
||||
default:
|
||||
t.Fatalf("unexpected extra probe request: %+v", req)
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
star := &StarSSH{}
|
||||
if !star.Exists(`C:\Windows\System32`) {
|
||||
t.Fatal("expected helper to succeed after PowerShell fallback")
|
||||
}
|
||||
if len(calls) != 2 {
|
||||
t.Fatalf("expected two probe attempts, got %d", len(calls))
|
||||
}
|
||||
if calls[0].ShellDialect != ExecShellDialectRaw || !strings.HasPrefix(calls[0].Command, "sh -lc ") {
|
||||
t.Fatalf("unexpected first probe request: %+v", calls[0])
|
||||
}
|
||||
if calls[1].ShellDialect != ExecShellDialectRaw || !strings.HasPrefix(strings.ToLower(calls[1].Command), "powershell.exe ") {
|
||||
t.Fatalf("unexpected second probe request: %+v", calls[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserFallsBackToCMDWhenPowerShellVariantsFail(t *testing.T) {
|
||||
oldExecRequestRunner := execRequestRunner
|
||||
t.Cleanup(func() {
|
||||
execRequestRunner = oldExecRequestRunner
|
||||
})
|
||||
|
||||
var calls []ExecRequest
|
||||
execRequestRunner = func(s *StarSSH, ctx context.Context, req ExecRequest) (*ExecResult, error) {
|
||||
calls = append(calls, req)
|
||||
if len(calls) < 5 {
|
||||
return &ExecResult{ExitCode: 127, Stderr: []byte("command not found")}, nil
|
||||
}
|
||||
return &ExecResult{Stdout: []byte("HOST\\tester\r\n")}, nil
|
||||
}
|
||||
|
||||
star := &StarSSH{}
|
||||
if got := star.GetUser(); got != `HOST\tester` {
|
||||
t.Fatalf("unexpected user after fallback: %q", got)
|
||||
}
|
||||
if len(calls) != 5 {
|
||||
t.Fatalf("expected five probe attempts, got %d", len(calls))
|
||||
}
|
||||
if got := strings.ToLower(calls[4].Command); got != "cmd.exe /q /d /c whoami" {
|
||||
t.Fatalf("unexpected final fallback command: %q", calls[4].Command)
|
||||
}
|
||||
}
|
||||
+952
@@ -0,0 +1,952 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type ForwardRequest struct {
|
||||
// Keep the exported shape compatible with older positional literals:
|
||||
// ForwardRequest{listenAddr, targetAddr, dialContext}.
|
||||
//
|
||||
// Non-default networks can be encoded with an explicit scheme-like prefix:
|
||||
// "tcp4://127.0.0.1:22", "tcp6://[::1]:22", "unix:///tmp/socket".
|
||||
// Bare values default to the "tcp" network.
|
||||
ListenAddr string
|
||||
TargetAddr string
|
||||
DialContext DialContextFunc
|
||||
}
|
||||
|
||||
type normalizedForwardRequest struct {
|
||||
ListenNetwork string
|
||||
ListenAddr string
|
||||
TargetNetwork string
|
||||
TargetAddr string
|
||||
DialContext DialContextFunc
|
||||
}
|
||||
|
||||
type DynamicForwardRequest struct {
|
||||
ListenAddr string
|
||||
}
|
||||
|
||||
type PortForwarder struct {
|
||||
listener net.Listener
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
acceptDone chan struct{}
|
||||
|
||||
connWG sync.WaitGroup
|
||||
closeOnce sync.Once
|
||||
|
||||
connMu sync.Mutex
|
||||
conns map[net.Conn]struct{}
|
||||
|
||||
errMu sync.Mutex
|
||||
err error
|
||||
|
||||
cleanupOnce sync.Once
|
||||
cleanupFns []func() error
|
||||
}
|
||||
|
||||
const unixForwardProbeTimeout = 200 * time.Millisecond
|
||||
|
||||
var dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
|
||||
return client.Dial(network, address)
|
||||
}
|
||||
|
||||
var listenSSHClient = func(client *ssh.Client, network, address string) (net.Listener, error) {
|
||||
return client.Listen(network, address)
|
||||
}
|
||||
|
||||
var newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return LoginContext(ctx, input)
|
||||
}
|
||||
|
||||
func (s *StarSSH) DialTCP(network string, address string) (net.Conn, error) {
|
||||
return s.DialTCPContext(context.Background(), network, address)
|
||||
}
|
||||
|
||||
func (s *StarSSH) DialTCPContext(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||
return s.dialTCPContext(ctx, network, address, nil)
|
||||
}
|
||||
|
||||
func (s *StarSSH) DialTCPContextCloseOnCancel(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||
return s.dialTCPContext(ctx, network, address, s.Close)
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartLocalTCPForward(listenAddr string, targetAddr string) (*PortForwarder, error) {
|
||||
return s.StartLocalForward(ForwardRequest{
|
||||
ListenAddr: listenAddr,
|
||||
TargetAddr: targetAddr,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartLocalTCPForwardDetached(listenAddr string, targetAddr string) (*PortForwarder, error) {
|
||||
return s.StartLocalForwardDetached(ForwardRequest{
|
||||
ListenAddr: listenAddr,
|
||||
TargetAddr: targetAddr,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartLocalTCPToUnixForward(listenAddr string, targetPath string) (*PortForwarder, error) {
|
||||
return s.StartLocalForward(ForwardRequest{
|
||||
ListenAddr: listenAddr,
|
||||
TargetAddr: forwardEndpoint("unix", targetPath),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartLocalTCPToUnixForwardDetached(listenAddr string, targetPath string) (*PortForwarder, error) {
|
||||
return s.StartLocalForwardDetached(ForwardRequest{
|
||||
ListenAddr: listenAddr,
|
||||
TargetAddr: forwardEndpoint("unix", targetPath),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartLocalUnixForward(listenPath string, targetAddr string) (*PortForwarder, error) {
|
||||
return s.StartLocalForward(ForwardRequest{
|
||||
ListenAddr: forwardEndpoint("unix", listenPath),
|
||||
TargetAddr: targetAddr,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartLocalUnixForwardDetached(listenPath string, targetAddr string) (*PortForwarder, error) {
|
||||
return s.StartLocalForwardDetached(ForwardRequest{
|
||||
ListenAddr: forwardEndpoint("unix", listenPath),
|
||||
TargetAddr: targetAddr,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartLocalUnixToUnixForward(listenPath string, targetPath string) (*PortForwarder, error) {
|
||||
return s.StartLocalForward(ForwardRequest{
|
||||
ListenAddr: forwardEndpoint("unix", listenPath),
|
||||
TargetAddr: forwardEndpoint("unix", targetPath),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartLocalUnixToUnixForwardDetached(listenPath string, targetPath string) (*PortForwarder, error) {
|
||||
return s.StartLocalForwardDetached(ForwardRequest{
|
||||
ListenAddr: forwardEndpoint("unix", listenPath),
|
||||
TargetAddr: forwardEndpoint("unix", targetPath),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartRemoteTCPForward(listenAddr string, targetAddr string) (*PortForwarder, error) {
|
||||
return s.StartRemoteForward(ForwardRequest{
|
||||
ListenAddr: listenAddr,
|
||||
TargetAddr: targetAddr,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartRemoteTCPToUnixForward(listenAddr string, targetPath string) (*PortForwarder, error) {
|
||||
return s.StartRemoteForward(ForwardRequest{
|
||||
ListenAddr: listenAddr,
|
||||
TargetAddr: forwardEndpoint("unix", targetPath),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartRemoteUnixForward(listenPath string, targetAddr string) (*PortForwarder, error) {
|
||||
return s.StartRemoteForward(ForwardRequest{
|
||||
ListenAddr: forwardEndpoint("unix", listenPath),
|
||||
TargetAddr: targetAddr,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartRemoteUnixToUnixForward(listenPath string, targetPath string) (*PortForwarder, error) {
|
||||
return s.StartRemoteForward(ForwardRequest{
|
||||
ListenAddr: forwardEndpoint("unix", listenPath),
|
||||
TargetAddr: forwardEndpoint("unix", targetPath),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *StarSSH) dialTCPContext(ctx context.Context, network string, address string, onCancel func() error) (net.Conn, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if strings.TrimSpace(network) == "" {
|
||||
network = "tcp"
|
||||
}
|
||||
if strings.TrimSpace(address) == "" {
|
||||
return nil, errors.New("forward address is empty")
|
||||
}
|
||||
|
||||
type dialResult struct {
|
||||
conn net.Conn
|
||||
err error
|
||||
}
|
||||
|
||||
client, err := s.requireSSHClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
runCancel := func() {}
|
||||
if onCancel != nil {
|
||||
var cancelOnce sync.Once
|
||||
runCancel = func() {
|
||||
cancelOnce.Do(func() {
|
||||
_ = onCancel()
|
||||
})
|
||||
}
|
||||
|
||||
cancelDone := make(chan struct{})
|
||||
defer close(cancelDone)
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
runCancel()
|
||||
case <-cancelDone:
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
dialFunc := dialSSHClient
|
||||
resultCh := make(chan dialResult, 1)
|
||||
go func() {
|
||||
conn, err := dialFunc(ctx, client, network, address)
|
||||
if ctx.Err() != nil && conn != nil {
|
||||
_ = conn.Close()
|
||||
conn = nil
|
||||
}
|
||||
|
||||
select {
|
||||
case resultCh <- dialResult{conn: conn, err: err}:
|
||||
default:
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
return result.conn, result.err
|
||||
case <-ctx.Done():
|
||||
runCancel()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartLocalForward(req ForwardRequest) (*PortForwarder, error) {
|
||||
if _, err := s.requireSSHClient(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
normalizedReq, err := normalizeForwardRequest(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(normalizedReq.ListenAddr) == "" {
|
||||
return nil, errors.New("local forward listen address is empty")
|
||||
}
|
||||
listener, cleanup, err := prepareLocalForwardListener(normalizedReq.ListenNetwork, normalizedReq.ListenAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
forwarder := newPortForwarder(listener)
|
||||
forwarder.addCleanup(cleanup)
|
||||
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
|
||||
return s.DialTCPContext(ctx, normalizedReq.TargetNetwork, normalizedReq.TargetAddr)
|
||||
})
|
||||
return forwarder, nil
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartLocalForwardDetached(req ForwardRequest) (*PortForwarder, error) {
|
||||
if _, err := s.requireSSHClient(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
normalizedReq, err := normalizeForwardRequest(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
listener, cleanup, err := prepareLocalForwardListener(normalizedReq.ListenNetwork, normalizedReq.ListenAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
forwardClient, err := s.newForwardDialClient(context.Background())
|
||||
if err != nil {
|
||||
_ = listener.Close()
|
||||
if cleanup != nil {
|
||||
_ = cleanup()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
forwarder := newPortForwarder(listener)
|
||||
forwarder.addCleanup(cleanup)
|
||||
forwarder.addCleanup(func() error {
|
||||
return normalizeAlreadyClosedError(forwardClient.Close())
|
||||
})
|
||||
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
|
||||
return forwardClient.DialTCPContext(ctx, normalizedReq.TargetNetwork, normalizedReq.TargetAddr)
|
||||
})
|
||||
return forwarder, nil
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartRemoteForward(req ForwardRequest) (*PortForwarder, error) {
|
||||
client, err := s.requireSSHClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
normalizedReq, err := normalizeForwardRequest(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
listener, err := listenSSHClient(client, normalizedReq.ListenNetwork, normalizedReq.ListenAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dialContext := normalizedReq.DialContext
|
||||
if dialContext == nil {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: defaultLoginTimeout,
|
||||
}
|
||||
dialContext = dialer.DialContext
|
||||
}
|
||||
|
||||
forwarder := newPortForwarder(listener)
|
||||
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
|
||||
return dialContext(ctx, normalizedReq.TargetNetwork, normalizedReq.TargetAddr)
|
||||
})
|
||||
return forwarder, nil
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartDynamicForward(req DynamicForwardRequest) (*PortForwarder, error) {
|
||||
if _, err := s.requireSSHClient(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(req.ListenAddr) == "" {
|
||||
return nil, errors.New("dynamic forward listen address is empty")
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", req.ListenAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
forwarder := newPortForwarder(listener)
|
||||
forwarder.serveDynamic(func(ctx context.Context, targetAddr string) (net.Conn, error) {
|
||||
return s.DialTCPContext(ctx, "tcp", targetAddr)
|
||||
})
|
||||
return forwarder, nil
|
||||
}
|
||||
|
||||
func normalizeForwardRequest(req ForwardRequest) (normalizedForwardRequest, error) {
|
||||
normalized := normalizedForwardRequest{
|
||||
DialContext: req.DialContext,
|
||||
}
|
||||
|
||||
var err error
|
||||
normalized.ListenNetwork, normalized.ListenAddr, err = parseForwardEndpoint(req.ListenAddr)
|
||||
if err != nil {
|
||||
return normalized, fmt.Errorf("normalize listen address: %w", err)
|
||||
}
|
||||
normalized.TargetNetwork, normalized.TargetAddr, err = parseForwardEndpoint(req.TargetAddr)
|
||||
if err != nil {
|
||||
return normalized, fmt.Errorf("normalize target address: %w", err)
|
||||
}
|
||||
if strings.TrimSpace(normalized.ListenAddr) == "" {
|
||||
return normalized, errors.New("forward listen address is empty")
|
||||
}
|
||||
if strings.TrimSpace(normalized.TargetAddr) == "" {
|
||||
return normalized, errors.New("forward target address is empty")
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
func normalizeForwardNetwork(network string) string {
|
||||
network = strings.ToLower(strings.TrimSpace(network))
|
||||
if network == "" {
|
||||
return "tcp"
|
||||
}
|
||||
return network
|
||||
}
|
||||
|
||||
func isSupportedForwardNetwork(network string) bool {
|
||||
switch network {
|
||||
case "tcp", "tcp4", "tcp6", "unix":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func parseForwardEndpoint(value string) (network string, address string, err error) {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return "tcp", "", nil
|
||||
}
|
||||
|
||||
lowerValue := strings.ToLower(value)
|
||||
for _, prefix := range []string{"tcp4://", "tcp6://", "tcp://", "unix://"} {
|
||||
if strings.HasPrefix(lowerValue, prefix) {
|
||||
network = normalizeForwardNetwork(strings.TrimSuffix(prefix, "://"))
|
||||
address = value[len(prefix):]
|
||||
if !isSupportedForwardNetwork(network) {
|
||||
return "", "", fmt.Errorf("unsupported forward network %q", network)
|
||||
}
|
||||
return network, address, nil
|
||||
}
|
||||
}
|
||||
return "tcp", value, nil
|
||||
}
|
||||
|
||||
func forwardEndpoint(network string, address string) string {
|
||||
network = normalizeForwardNetwork(network)
|
||||
if network == "tcp" {
|
||||
return address
|
||||
}
|
||||
return network + "://" + address
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartDynamicForwardDetached(req DynamicForwardRequest) (*PortForwarder, error) {
|
||||
if _, err := s.requireSSHClient(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(req.ListenAddr) == "" {
|
||||
return nil, errors.New("dynamic forward listen address is empty")
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", req.ListenAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
forwardClient, err := s.newForwardDialClient(context.Background())
|
||||
if err != nil {
|
||||
_ = listener.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
forwarder := newPortForwarder(listener)
|
||||
forwarder.addCleanup(func() error {
|
||||
return normalizeAlreadyClosedError(forwardClient.Close())
|
||||
})
|
||||
forwarder.serveDynamic(func(ctx context.Context, targetAddr string) (net.Conn, error) {
|
||||
return forwardClient.DialTCPContext(ctx, "tcp", targetAddr)
|
||||
})
|
||||
return forwarder, nil
|
||||
}
|
||||
|
||||
func (f *PortForwarder) Addr() net.Addr {
|
||||
if f == nil || f.listener == nil {
|
||||
return nil
|
||||
}
|
||||
return f.listener.Addr()
|
||||
}
|
||||
|
||||
func (f *PortForwarder) Wait() error {
|
||||
if f == nil {
|
||||
return nil
|
||||
}
|
||||
<-f.acceptDone
|
||||
f.connWG.Wait()
|
||||
f.runCleanup()
|
||||
return f.Err()
|
||||
}
|
||||
|
||||
func (f *PortForwarder) Err() error {
|
||||
if f == nil {
|
||||
return nil
|
||||
}
|
||||
f.errMu.Lock()
|
||||
defer f.errMu.Unlock()
|
||||
return f.err
|
||||
}
|
||||
|
||||
func (f *PortForwarder) Close() error {
|
||||
if f == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var closeErr error
|
||||
f.closeOnce.Do(func() {
|
||||
if f.cancel != nil {
|
||||
f.cancel()
|
||||
}
|
||||
if f.listener != nil {
|
||||
closeErr = normalizeAlreadyClosedError(f.listener.Close())
|
||||
}
|
||||
f.closeActiveConnections()
|
||||
})
|
||||
|
||||
<-f.acceptDone
|
||||
f.connWG.Wait()
|
||||
f.runCleanup()
|
||||
if closeErr != nil {
|
||||
return closeErr
|
||||
}
|
||||
return f.Err()
|
||||
}
|
||||
|
||||
func newPortForwarder(listener net.Listener) *PortForwarder {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &PortForwarder{
|
||||
listener: listener,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
acceptDone: make(chan struct{}),
|
||||
conns: make(map[net.Conn]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StarSSH) newForwardDialClient(ctx context.Context) (*StarSSH, error) {
|
||||
if s == nil {
|
||||
return nil, errors.New("ssh client is nil")
|
||||
}
|
||||
return newDetachedForwardClient(ctx, s.LoginInfo)
|
||||
}
|
||||
|
||||
func (f *PortForwarder) addCleanup(fn func() error) {
|
||||
if f == nil || fn == nil {
|
||||
return
|
||||
}
|
||||
f.cleanupFns = append(f.cleanupFns, fn)
|
||||
}
|
||||
|
||||
func prepareLocalForwardListener(network string, address string) (net.Listener, func() error, error) {
|
||||
network = normalizeForwardNetwork(network)
|
||||
if network != "unix" {
|
||||
listener, err := net.Listen(network, address)
|
||||
return listener, nil, err
|
||||
}
|
||||
|
||||
if err := removeStaleUnixSocket(address); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
listener, err := net.Listen(network, address)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
cleanup, err := makeUnixSocketCleanup(address)
|
||||
if err != nil {
|
||||
_ = listener.Close()
|
||||
_ = removeUnixSocketPath(address)
|
||||
return nil, nil, err
|
||||
}
|
||||
return listener, cleanup, nil
|
||||
}
|
||||
|
||||
func removeStaleUnixSocket(path string) error {
|
||||
info, err := os.Lstat(path)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if info.Mode()&os.ModeSocket == 0 {
|
||||
return fmt.Errorf("local unix forward path %q already exists and is not a socket", path)
|
||||
}
|
||||
|
||||
conn, err := net.DialTimeout("unix", path, unixForwardProbeTimeout)
|
||||
if err == nil {
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("local unix forward path %q is already in use", path)
|
||||
}
|
||||
if !isStaleUnixSocketDialError(err) {
|
||||
return fmt.Errorf("probe existing unix socket %q: %w", path, err)
|
||||
}
|
||||
return removeUnixSocketPath(path)
|
||||
}
|
||||
|
||||
func isStaleUnixSocketDialError(err error) bool {
|
||||
return errors.Is(err, syscall.ECONNREFUSED) || errors.Is(err, syscall.ENOENT)
|
||||
}
|
||||
|
||||
func makeUnixSocketCleanup(path string) (func() error, error) {
|
||||
info, err := os.Lstat(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return func() error {
|
||||
current, err := os.Lstat(path)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if current.Mode()&os.ModeSocket == 0 || !os.SameFile(info, current) {
|
||||
return nil
|
||||
}
|
||||
return removeUnixSocketPath(path)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func removeUnixSocketPath(path string) error {
|
||||
err := os.Remove(path)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (f *PortForwarder) runCleanup() {
|
||||
if f == nil {
|
||||
return
|
||||
}
|
||||
|
||||
f.cleanupOnce.Do(func() {
|
||||
for _, fn := range f.cleanupFns {
|
||||
f.setError(normalizeAlreadyClosedError(fn()))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (f *PortForwarder) serve(targetDial func(context.Context) (net.Conn, error)) {
|
||||
go func() {
|
||||
defer close(f.acceptDone)
|
||||
|
||||
for {
|
||||
conn, err := f.listener.Accept()
|
||||
if err != nil {
|
||||
if isClosedListenerError(err) {
|
||||
return
|
||||
}
|
||||
f.setError(err)
|
||||
return
|
||||
}
|
||||
|
||||
f.trackConn(conn)
|
||||
f.connWG.Add(1)
|
||||
go func(src net.Conn) {
|
||||
defer f.connWG.Done()
|
||||
defer f.untrackConn(src)
|
||||
|
||||
dst, err := targetDial(f.ctx)
|
||||
if err != nil {
|
||||
f.setError(err)
|
||||
_ = src.Close()
|
||||
return
|
||||
}
|
||||
f.trackConn(dst)
|
||||
defer f.untrackConn(dst)
|
||||
|
||||
f.setError(pipeForwardConnections(src, dst))
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (f *PortForwarder) serveDynamic(targetDial func(context.Context, string) (net.Conn, error)) {
|
||||
go func() {
|
||||
defer close(f.acceptDone)
|
||||
|
||||
for {
|
||||
conn, err := f.listener.Accept()
|
||||
if err != nil {
|
||||
if isClosedListenerError(err) {
|
||||
return
|
||||
}
|
||||
f.setError(err)
|
||||
return
|
||||
}
|
||||
|
||||
f.trackConn(conn)
|
||||
f.connWG.Add(1)
|
||||
go func(src net.Conn) {
|
||||
defer f.connWG.Done()
|
||||
defer f.untrackConn(src)
|
||||
if err := handleDynamicForwardConn(f.ctx, src, targetDial, f.trackConn, f.untrackConn); err != nil {
|
||||
f.setError(err)
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (f *PortForwarder) setError(err error) {
|
||||
if f == nil || err == nil || f.shouldIgnoreError(err) {
|
||||
return
|
||||
}
|
||||
|
||||
f.errMu.Lock()
|
||||
defer f.errMu.Unlock()
|
||||
if f.err == nil {
|
||||
f.err = err
|
||||
}
|
||||
}
|
||||
|
||||
func pipeForwardConnections(left net.Conn, right net.Conn) error {
|
||||
if left == nil || right == nil {
|
||||
if left != nil {
|
||||
_ = left.Close()
|
||||
}
|
||||
if right != nil {
|
||||
_ = right.Close()
|
||||
}
|
||||
return errors.New("forward connection endpoint is nil")
|
||||
}
|
||||
|
||||
defer left.Close()
|
||||
defer right.Close()
|
||||
|
||||
var copyWG sync.WaitGroup
|
||||
errCh := make(chan error, 2)
|
||||
copyWG.Add(2)
|
||||
go func() {
|
||||
defer copyWG.Done()
|
||||
_, err := io.Copy(right, left)
|
||||
errCh <- normalizeAlreadyClosedError(err)
|
||||
closeWrite(right)
|
||||
}()
|
||||
go func() {
|
||||
defer copyWG.Done()
|
||||
_, err := io.Copy(left, right)
|
||||
errCh <- normalizeAlreadyClosedError(err)
|
||||
closeWrite(left)
|
||||
}()
|
||||
copyWG.Wait()
|
||||
close(errCh)
|
||||
|
||||
for err := range errCh {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func closeWrite(conn net.Conn) {
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
type closeWriter interface {
|
||||
CloseWrite() error
|
||||
}
|
||||
if writer, ok := conn.(closeWriter); ok {
|
||||
_ = writer.CloseWrite()
|
||||
}
|
||||
}
|
||||
|
||||
func handleDynamicForwardConn(
|
||||
ctx context.Context,
|
||||
src net.Conn,
|
||||
targetDial func(context.Context, string) (net.Conn, error),
|
||||
trackConn func(net.Conn),
|
||||
untrackConn func(net.Conn),
|
||||
) error {
|
||||
if src == nil {
|
||||
return errors.New("dynamic forward source connection is nil")
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
if err := negotiateSOCKS5NoAuth(src); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
targetAddr, replyCode, err := readSOCKS5ConnectTarget(src)
|
||||
if err != nil {
|
||||
_ = writeSOCKS5ServerReply(src, replyCode, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
dst, err := targetDial(ctx, targetAddr)
|
||||
if err != nil {
|
||||
_ = writeSOCKS5ServerReply(src, 0x01, nil)
|
||||
return err
|
||||
}
|
||||
if trackConn != nil {
|
||||
trackConn(dst)
|
||||
defer untrackConn(dst)
|
||||
}
|
||||
|
||||
if err := writeSOCKS5ServerReply(src, 0x00, dst.LocalAddr()); err != nil {
|
||||
_ = dst.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return pipeForwardConnections(src, dst)
|
||||
}
|
||||
|
||||
func (f *PortForwarder) trackConn(conn net.Conn) {
|
||||
if f == nil || conn == nil {
|
||||
return
|
||||
}
|
||||
f.connMu.Lock()
|
||||
defer f.connMu.Unlock()
|
||||
f.conns[conn] = struct{}{}
|
||||
}
|
||||
|
||||
func (f *PortForwarder) untrackConn(conn net.Conn) {
|
||||
if f == nil || conn == nil {
|
||||
return
|
||||
}
|
||||
f.connMu.Lock()
|
||||
defer f.connMu.Unlock()
|
||||
delete(f.conns, conn)
|
||||
}
|
||||
|
||||
func (f *PortForwarder) closeActiveConnections() {
|
||||
if f == nil {
|
||||
return
|
||||
}
|
||||
|
||||
f.connMu.Lock()
|
||||
conns := make([]net.Conn, 0, len(f.conns))
|
||||
for conn := range f.conns {
|
||||
conns = append(conns, conn)
|
||||
}
|
||||
f.connMu.Unlock()
|
||||
|
||||
for _, conn := range conns {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (f *PortForwarder) shouldIgnoreError(err error) bool {
|
||||
if err == nil {
|
||||
return true
|
||||
}
|
||||
if normalizeAlreadyClosedError(err) == nil {
|
||||
return true
|
||||
}
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func negotiateSOCKS5NoAuth(conn net.Conn) error {
|
||||
header := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, header); err != nil {
|
||||
return err
|
||||
}
|
||||
if header[0] != 0x05 {
|
||||
return errors.New("invalid socks5 version")
|
||||
}
|
||||
|
||||
methodCount := int(header[1])
|
||||
methods := make([]byte, methodCount)
|
||||
if _, err := io.ReadFull(conn, methods); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
method := byte(0xFF)
|
||||
for _, candidate := range methods {
|
||||
if candidate == 0x00 {
|
||||
method = 0x00
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := conn.Write([]byte{0x05, method}); err != nil {
|
||||
return err
|
||||
}
|
||||
if method == 0xFF {
|
||||
return errors.New("socks5 client does not support no-auth method")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func readSOCKS5ConnectTarget(conn net.Conn) (string, byte, error) {
|
||||
header := make([]byte, 4)
|
||||
if _, err := io.ReadFull(conn, header); err != nil {
|
||||
return "", 0x01, err
|
||||
}
|
||||
if header[0] != 0x05 {
|
||||
return "", 0x01, errors.New("invalid socks5 request version")
|
||||
}
|
||||
if header[1] != 0x01 {
|
||||
return "", 0x07, errors.New("unsupported socks5 command")
|
||||
}
|
||||
|
||||
host, err := readSOCKS5RequestHost(conn, header[3])
|
||||
if err != nil {
|
||||
return "", 0x08, err
|
||||
}
|
||||
|
||||
portBytes := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, portBytes); err != nil {
|
||||
return "", 0x01, err
|
||||
}
|
||||
port := int(portBytes[0])<<8 | int(portBytes[1])
|
||||
return net.JoinHostPort(host, strconv.Itoa(port)), 0x00, nil
|
||||
}
|
||||
|
||||
func readSOCKS5RequestHost(conn net.Conn, addressType byte) (string, error) {
|
||||
switch addressType {
|
||||
case 0x01:
|
||||
buffer := make([]byte, 4)
|
||||
if _, err := io.ReadFull(conn, buffer); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return net.IP(buffer).String(), nil
|
||||
case 0x03:
|
||||
size := make([]byte, 1)
|
||||
if _, err := io.ReadFull(conn, size); err != nil {
|
||||
return "", err
|
||||
}
|
||||
buffer := make([]byte, int(size[0]))
|
||||
if _, err := io.ReadFull(conn, buffer); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(buffer), nil
|
||||
case 0x04:
|
||||
buffer := make([]byte, 16)
|
||||
if _, err := io.ReadFull(conn, buffer); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return net.IP(buffer).String(), nil
|
||||
default:
|
||||
return "", errors.New("unsupported socks5 address type")
|
||||
}
|
||||
}
|
||||
|
||||
func writeSOCKS5ServerReply(conn net.Conn, replyCode byte, addr net.Addr) error {
|
||||
reply := []byte{0x05, replyCode, 0x00}
|
||||
|
||||
if tcpAddr, ok := addr.(*net.TCPAddr); ok && tcpAddr != nil {
|
||||
if ip4 := tcpAddr.IP.To4(); ip4 != nil {
|
||||
reply = append(reply, 0x01)
|
||||
reply = append(reply, ip4...)
|
||||
} else if ip16 := tcpAddr.IP.To16(); ip16 != nil {
|
||||
reply = append(reply, 0x04)
|
||||
reply = append(reply, ip16...)
|
||||
}
|
||||
if len(reply) > 3 {
|
||||
port := tcpAddr.Port
|
||||
reply = append(reply, byte(port>>8), byte(port))
|
||||
_, err := conn.Write(reply)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
reply = append(reply, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00)
|
||||
_, err := conn.Write(reply)
|
||||
return err
|
||||
}
|
||||
|
||||
func isClosedListenerError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
return true
|
||||
}
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return true
|
||||
}
|
||||
return strings.Contains(err.Error(), "use of closed network connection")
|
||||
}
|
||||
+698
@@ -0,0 +1,698 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type stubListener struct {
|
||||
addr net.Addr
|
||||
acceptCh chan net.Conn
|
||||
closeCh chan struct{}
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
type dialRecord struct {
|
||||
network string
|
||||
addr string
|
||||
}
|
||||
|
||||
func newStubListener(addr net.Addr) *stubListener {
|
||||
return &stubListener{
|
||||
addr: addr,
|
||||
acceptCh: make(chan net.Conn, 1),
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *stubListener) Accept() (net.Conn, error) {
|
||||
select {
|
||||
case conn, ok := <-l.acceptCh:
|
||||
if !ok {
|
||||
return nil, io.EOF
|
||||
}
|
||||
return conn, nil
|
||||
case <-l.closeCh:
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
}
|
||||
|
||||
func (l *stubListener) Close() error {
|
||||
l.closeOnce.Do(func() {
|
||||
close(l.closeCh)
|
||||
close(l.acceptCh)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *stubListener) Addr() net.Addr {
|
||||
return l.addr
|
||||
}
|
||||
|
||||
func (l *stubListener) Push(conn net.Conn) error {
|
||||
select {
|
||||
case <-l.closeCh:
|
||||
return net.ErrClosed
|
||||
case l.acceptCh <- conn:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartLocalForwardUsesExistingConnectionByDefault(t *testing.T) {
|
||||
oldDialSSHClient := dialSSHClient
|
||||
oldNewDetachedForwardClient := newDetachedForwardClient
|
||||
oldCloseSSHClient := closeSSHClient
|
||||
t.Cleanup(func() {
|
||||
dialSSHClient = oldDialSSHClient
|
||||
newDetachedForwardClient = oldNewDetachedForwardClient
|
||||
closeSSHClient = oldCloseSSHClient
|
||||
})
|
||||
|
||||
baseClient := &ssh.Client{}
|
||||
star := &StarSSH{}
|
||||
star.setTransport(baseClient, nil)
|
||||
|
||||
var detachedCalls atomic.Int32
|
||||
newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) {
|
||||
detachedCalls.Add(1)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
|
||||
if client != baseClient {
|
||||
t.Errorf("expected existing ssh client, got %p want %p", client, baseClient)
|
||||
}
|
||||
serverConn, clientConn := net.Pipe()
|
||||
go echoForwardPipe(serverConn)
|
||||
return clientConn, nil
|
||||
}
|
||||
|
||||
closeSSHClient = func(client sshClientRequester) error {
|
||||
t.Fatal("default local forward should not close the main ssh client")
|
||||
return nil
|
||||
}
|
||||
|
||||
forwarder, err := star.StartLocalForward(ForwardRequest{
|
||||
ListenAddr: "127.0.0.1:0",
|
||||
TargetAddr: "example.internal:22",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("start local forward: %v", err)
|
||||
}
|
||||
defer forwarder.Close()
|
||||
|
||||
reply := exerciseForwarder(t, forwarder.Addr().String(), []byte("ping"))
|
||||
if string(reply) != "ping" {
|
||||
t.Fatalf("unexpected forwarded reply: %q", string(reply))
|
||||
}
|
||||
if detachedCalls.Load() != 0 {
|
||||
t.Fatalf("default local forward should not create detached ssh client, calls=%d", detachedCalls.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardRequestLegacyPositionalLiteralDefaultsToTCP(t *testing.T) {
|
||||
dialer := func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
req, err := normalizeForwardRequest(ForwardRequest{
|
||||
"127.0.0.1:10022",
|
||||
"example.internal:22",
|
||||
dialer,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeForwardRequest: %v", err)
|
||||
}
|
||||
if req.ListenNetwork != "tcp" {
|
||||
t.Fatalf("ListenNetwork=%q want tcp", req.ListenNetwork)
|
||||
}
|
||||
if req.TargetNetwork != "tcp" {
|
||||
t.Fatalf("TargetNetwork=%q want tcp", req.TargetNetwork)
|
||||
}
|
||||
if req.ListenAddr != "127.0.0.1:10022" || req.TargetAddr != "example.internal:22" {
|
||||
t.Fatalf("unexpected normalized request: %+v", req)
|
||||
}
|
||||
if req.DialContext == nil {
|
||||
t.Fatal("expected DialContext to be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseForwardEndpointTreatsTCPPrefixLikePlainAddress(t *testing.T) {
|
||||
network, address, err := parseForwardEndpoint("tcp:22")
|
||||
if err != nil {
|
||||
t.Fatalf("parseForwardEndpoint: %v", err)
|
||||
}
|
||||
if network != "tcp" {
|
||||
t.Fatalf("network=%q want tcp", network)
|
||||
}
|
||||
if address != "tcp:22" {
|
||||
t.Fatalf("address=%q want tcp:22", address)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseForwardEndpointSupportsExplicitSchemes(t *testing.T) {
|
||||
network, address, err := parseForwardEndpoint("unix:///tmp/test-forward.sock")
|
||||
if err != nil {
|
||||
t.Fatalf("parseForwardEndpoint unix: %v", err)
|
||||
}
|
||||
if network != "unix" || address != "/tmp/test-forward.sock" {
|
||||
t.Fatalf("unexpected unix endpoint parse: network=%q address=%q", network, address)
|
||||
}
|
||||
|
||||
network, address, err = parseForwardEndpoint("tcp6://[::1]:2222")
|
||||
if err != nil {
|
||||
t.Fatalf("parseForwardEndpoint tcp6: %v", err)
|
||||
}
|
||||
if network != "tcp6" || address != "[::1]:2222" {
|
||||
t.Fatalf("unexpected tcp6 endpoint parse: network=%q address=%q", network, address)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartLocalForwardDetachedUsesSeparateConnection(t *testing.T) {
|
||||
oldDialSSHClient := dialSSHClient
|
||||
oldNewDetachedForwardClient := newDetachedForwardClient
|
||||
oldCloseSSHClient := closeSSHClient
|
||||
t.Cleanup(func() {
|
||||
dialSSHClient = oldDialSSHClient
|
||||
newDetachedForwardClient = oldNewDetachedForwardClient
|
||||
closeSSHClient = oldCloseSSHClient
|
||||
})
|
||||
|
||||
baseClient := &ssh.Client{}
|
||||
detachedClient := &ssh.Client{}
|
||||
star := &StarSSH{LoginInfo: LoginInput{User: "tester", Addr: "127.0.0.1"}}
|
||||
star.setTransport(baseClient, nil)
|
||||
|
||||
forwardClient := &StarSSH{}
|
||||
forwardClient.setTransport(detachedClient, nil)
|
||||
|
||||
var detachedCalls atomic.Int32
|
||||
newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) {
|
||||
detachedCalls.Add(1)
|
||||
return forwardClient, nil
|
||||
}
|
||||
|
||||
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
|
||||
if client != detachedClient {
|
||||
t.Errorf("expected detached ssh client, got %p want %p", client, detachedClient)
|
||||
}
|
||||
serverConn, clientConn := net.Pipe()
|
||||
go echoForwardPipe(serverConn)
|
||||
return clientConn, nil
|
||||
}
|
||||
|
||||
var closeCalls atomic.Int32
|
||||
closeSSHClient = func(client sshClientRequester) error {
|
||||
closeCalls.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
forwarder, err := star.StartLocalForwardDetached(ForwardRequest{
|
||||
ListenAddr: "127.0.0.1:0",
|
||||
TargetAddr: "example.internal:22",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("start detached local forward: %v", err)
|
||||
}
|
||||
|
||||
reply := exerciseForwarder(t, forwarder.Addr().String(), []byte("pong"))
|
||||
if string(reply) != "pong" {
|
||||
t.Fatalf("unexpected detached forwarded reply: %q", string(reply))
|
||||
}
|
||||
|
||||
if err := forwarder.Close(); err != nil {
|
||||
t.Fatalf("close detached local forward: %v", err)
|
||||
}
|
||||
if detachedCalls.Load() != 1 {
|
||||
t.Fatalf("expected one detached ssh login, got %d", detachedCalls.Load())
|
||||
}
|
||||
if closeCalls.Load() != 1 {
|
||||
t.Fatalf("expected detached ssh client cleanup once, got %d", closeCalls.Load())
|
||||
}
|
||||
if got := star.snapshotSSHClient(); got != baseClient {
|
||||
t.Fatal("detached local forward should not detach the main ssh client")
|
||||
}
|
||||
if got := forwardClient.snapshotSSHClient(); got != nil {
|
||||
t.Fatal("detached local forward should close its detached ssh client")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartRemoteForwardSupportsUnixListenAndTCPTarget(t *testing.T) {
|
||||
oldListenSSHClient := listenSSHClient
|
||||
t.Cleanup(func() {
|
||||
listenSSHClient = oldListenSSHClient
|
||||
})
|
||||
|
||||
baseClient := &ssh.Client{}
|
||||
star := &StarSSH{}
|
||||
star.setTransport(baseClient, nil)
|
||||
|
||||
listener := newStubListener(&net.UnixAddr{
|
||||
Name: "/run/user/0/gnupg/S.gpg-agent",
|
||||
Net: "unix",
|
||||
})
|
||||
|
||||
var listenedNetwork string
|
||||
var listenedAddr string
|
||||
listenSSHClient = func(client *ssh.Client, network, address string) (net.Listener, error) {
|
||||
if client != baseClient {
|
||||
t.Fatalf("unexpected ssh client %p", client)
|
||||
}
|
||||
listenedNetwork = network
|
||||
listenedAddr = address
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
var targetNetwork string
|
||||
var targetAddr string
|
||||
forwarder, err := star.StartRemoteForward(ForwardRequest{
|
||||
ListenAddr: forwardEndpoint("unix", "/run/user/0/gnupg/S.gpg-agent"),
|
||||
TargetAddr: "127.0.0.1:4321",
|
||||
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
targetNetwork = network
|
||||
targetAddr = address
|
||||
serverConn, clientConn := net.Pipe()
|
||||
go echoForwardPipe(serverConn)
|
||||
return clientConn, nil
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("start remote unix forward: %v", err)
|
||||
}
|
||||
defer forwarder.Close()
|
||||
|
||||
srcPeer, forwardedConn := net.Pipe()
|
||||
defer srcPeer.Close()
|
||||
if err := listener.Push(forwardedConn); err != nil {
|
||||
t.Fatalf("push forwarded connection: %v", err)
|
||||
}
|
||||
|
||||
payload := []byte("unix-forward")
|
||||
done := make(chan []byte, 1)
|
||||
go func() {
|
||||
reply := make([]byte, len(payload))
|
||||
_, _ = io.ReadFull(srcPeer, reply)
|
||||
done <- reply
|
||||
}()
|
||||
|
||||
if _, err := srcPeer.Write(payload); err != nil {
|
||||
t.Fatalf("write source payload: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case reply := <-done:
|
||||
if string(reply) != string(payload) {
|
||||
t.Fatalf("unexpected remote unix forward reply: %q", string(reply))
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("remote unix forward did not relay payload")
|
||||
}
|
||||
|
||||
if listenedNetwork != "unix" || listenedAddr != "/run/user/0/gnupg/S.gpg-agent" {
|
||||
t.Fatalf("unexpected remote listen request: network=%q addr=%q", listenedNetwork, listenedAddr)
|
||||
}
|
||||
if targetNetwork != "tcp" || targetAddr != "127.0.0.1:4321" {
|
||||
t.Fatalf("unexpected local dial target: network=%q addr=%q", targetNetwork, targetAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartLocalUnixForwardUsesUnixListenerAndTCPTarget(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("unix socket smoke test is exercised in WSL/Linux CI path")
|
||||
}
|
||||
|
||||
oldDialSSHClient := dialSSHClient
|
||||
t.Cleanup(func() {
|
||||
dialSSHClient = oldDialSSHClient
|
||||
})
|
||||
|
||||
baseClient := &ssh.Client{}
|
||||
star := &StarSSH{}
|
||||
star.setTransport(baseClient, nil)
|
||||
|
||||
var targetNetwork string
|
||||
var targetAddr string
|
||||
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
|
||||
if client != baseClient {
|
||||
t.Fatalf("unexpected ssh client %p", client)
|
||||
}
|
||||
targetNetwork = network
|
||||
targetAddr = address
|
||||
serverConn, clientConn := net.Pipe()
|
||||
go echoForwardPipe(serverConn)
|
||||
return clientConn, nil
|
||||
}
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "forward.sock")
|
||||
forwarder, err := star.StartLocalUnixForward(socketPath, "127.0.0.1:4321")
|
||||
if err != nil {
|
||||
t.Fatalf("start local unix forward: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
closeErr := forwarder.Close()
|
||||
if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) {
|
||||
t.Fatalf("close local unix forward: %v", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := net.DialTimeout("unix", socketPath, time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("dial unix forward listener: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
|
||||
|
||||
payload := []byte("unix-local-forward")
|
||||
if _, err := conn.Write(payload); err != nil {
|
||||
t.Fatalf("write unix forward payload: %v", err)
|
||||
}
|
||||
reply := make([]byte, len(payload))
|
||||
if _, err := io.ReadFull(conn, reply); err != nil {
|
||||
t.Fatalf("read unix forward reply: %v", err)
|
||||
}
|
||||
if string(reply) != string(payload) {
|
||||
t.Fatalf("unexpected unix forward reply: %q", string(reply))
|
||||
}
|
||||
if targetNetwork != "tcp" || targetAddr != "127.0.0.1:4321" {
|
||||
t.Fatalf("unexpected remote dial target: network=%q addr=%q", targetNetwork, targetAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartLocalUnixForwardRemovesSocketOnClose(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("unix socket smoke test is exercised in WSL/Linux CI path")
|
||||
}
|
||||
|
||||
oldDialSSHClient := dialSSHClient
|
||||
t.Cleanup(func() {
|
||||
dialSSHClient = oldDialSSHClient
|
||||
})
|
||||
|
||||
baseClient := &ssh.Client{}
|
||||
star := &StarSSH{}
|
||||
star.setTransport(baseClient, nil)
|
||||
|
||||
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
|
||||
serverConn, clientConn := net.Pipe()
|
||||
go echoForwardPipe(serverConn)
|
||||
return clientConn, nil
|
||||
}
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "cleanup.sock")
|
||||
forwarder, err := star.StartLocalUnixForward(socketPath, "127.0.0.1:4321")
|
||||
if err != nil {
|
||||
t.Fatalf("start local unix forward: %v", err)
|
||||
}
|
||||
|
||||
if _, err := os.Lstat(socketPath); err != nil {
|
||||
t.Fatalf("socket should exist while forward is running: %v", err)
|
||||
}
|
||||
if err := forwarder.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
t.Fatalf("close local unix forward: %v", err)
|
||||
}
|
||||
if _, err := os.Lstat(socketPath); !errors.Is(err, os.ErrNotExist) {
|
||||
t.Fatalf("socket path should be removed on close, got err=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartLocalUnixForwardReusesStaleSocketPath(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("unix socket smoke test is exercised in WSL/Linux CI path")
|
||||
}
|
||||
|
||||
oldDialSSHClient := dialSSHClient
|
||||
t.Cleanup(func() {
|
||||
dialSSHClient = oldDialSSHClient
|
||||
})
|
||||
|
||||
baseClient := &ssh.Client{}
|
||||
star := &StarSSH{}
|
||||
star.setTransport(baseClient, nil)
|
||||
|
||||
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
|
||||
serverConn, clientConn := net.Pipe()
|
||||
go echoForwardPipe(serverConn)
|
||||
return clientConn, nil
|
||||
}
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "stale.sock")
|
||||
staleListener, err := net.ListenUnix("unix", &net.UnixAddr{
|
||||
Name: socketPath,
|
||||
Net: "unix",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("create stale unix socket: %v", err)
|
||||
}
|
||||
staleListener.SetUnlinkOnClose(false)
|
||||
if err := staleListener.Close(); err != nil {
|
||||
t.Fatalf("close stale unix socket listener: %v", err)
|
||||
}
|
||||
if _, err := os.Lstat(socketPath); err != nil {
|
||||
t.Fatalf("expected stale unix socket path to remain after close: %v", err)
|
||||
}
|
||||
|
||||
forwarder, err := star.StartLocalUnixForward(socketPath, "127.0.0.1:4321")
|
||||
if err != nil {
|
||||
t.Fatalf("start local unix forward on stale socket path: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
closeErr := forwarder.Close()
|
||||
if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) {
|
||||
t.Fatalf("close local unix forward: %v", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
reply := make([]byte, len("stale-reuse"))
|
||||
conn, err := net.DialTimeout("unix", socketPath, time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("dial reused unix forward listener: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
|
||||
if _, err := conn.Write([]byte("stale-reuse")); err != nil {
|
||||
t.Fatalf("write reused unix forward payload: %v", err)
|
||||
}
|
||||
if _, err := io.ReadFull(conn, reply); err != nil {
|
||||
t.Fatalf("read reused unix forward reply: %v", err)
|
||||
}
|
||||
if string(reply) != "stale-reuse" {
|
||||
t.Fatalf("unexpected reply on reused unix forward: %q", string(reply))
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartLocalUnixToUnixForwardUsesUnixTarget(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("unix socket smoke test is exercised in WSL/Linux CI path")
|
||||
}
|
||||
|
||||
oldDialSSHClient := dialSSHClient
|
||||
t.Cleanup(func() {
|
||||
dialSSHClient = oldDialSSHClient
|
||||
})
|
||||
|
||||
baseClient := &ssh.Client{}
|
||||
star := &StarSSH{}
|
||||
star.setTransport(baseClient, nil)
|
||||
|
||||
targetSocketPath := filepath.Join(t.TempDir(), "target.sock")
|
||||
targetListener, err := net.Listen("unix", targetSocketPath)
|
||||
if err != nil {
|
||||
t.Fatalf("listen target unix socket: %v", err)
|
||||
}
|
||||
defer targetListener.Close()
|
||||
|
||||
done := make(chan []byte, 1)
|
||||
go func() {
|
||||
conn, acceptErr := targetListener.Accept()
|
||||
if acceptErr != nil {
|
||||
done <- nil
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
buf := make([]byte, 64)
|
||||
n, _ := conn.Read(buf)
|
||||
_, _ = conn.Write(buf[:n])
|
||||
done <- buf[:n]
|
||||
}()
|
||||
|
||||
dialRecordCh := make(chan dialRecord, 1)
|
||||
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
|
||||
if client != baseClient {
|
||||
t.Fatalf("unexpected ssh client %p", client)
|
||||
}
|
||||
dialRecordCh <- dialRecord{network: network, addr: address}
|
||||
var dialer net.Dialer
|
||||
return dialer.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
listenSocketPath := filepath.Join(t.TempDir(), "listen.sock")
|
||||
forwarder, err := star.StartLocalUnixToUnixForward(listenSocketPath, targetSocketPath)
|
||||
if err != nil {
|
||||
t.Fatalf("start local unix-to-unix forward: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
closeErr := forwarder.Close()
|
||||
if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) {
|
||||
t.Fatalf("close local unix-to-unix forward: %v", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := net.DialTimeout("unix", listenSocketPath, time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("dial unix-to-unix listener: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
|
||||
|
||||
payload := []byte("unix-to-unix")
|
||||
if _, err := conn.Write(payload); err != nil {
|
||||
t.Fatalf("write unix-to-unix payload: %v", err)
|
||||
}
|
||||
reply := make([]byte, len(payload))
|
||||
if _, err := io.ReadFull(conn, reply); err != nil {
|
||||
t.Fatalf("read unix-to-unix reply: %v", err)
|
||||
}
|
||||
if string(reply) != string(payload) {
|
||||
t.Fatalf("unexpected unix-to-unix reply: %q", string(reply))
|
||||
}
|
||||
|
||||
select {
|
||||
case got := <-done:
|
||||
if string(got) != string(payload) {
|
||||
t.Fatalf("unexpected payload seen by target unix socket: %q", string(got))
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("target unix socket did not receive forwarded payload")
|
||||
}
|
||||
|
||||
select {
|
||||
case got := <-dialRecordCh:
|
||||
if got.network != "unix" || got.addr != targetSocketPath {
|
||||
t.Fatalf("unexpected unix target dial: network=%q addr=%q", got.network, got.addr)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("did not observe unix target dial")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartLocalTCPToUnixForwardUsesUnixTarget(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("unix socket smoke test is exercised in WSL/Linux CI path")
|
||||
}
|
||||
|
||||
oldDialSSHClient := dialSSHClient
|
||||
t.Cleanup(func() {
|
||||
dialSSHClient = oldDialSSHClient
|
||||
})
|
||||
|
||||
baseClient := &ssh.Client{}
|
||||
star := &StarSSH{}
|
||||
star.setTransport(baseClient, nil)
|
||||
|
||||
targetSocketPath := filepath.Join(t.TempDir(), "target-tcp-to-unix.sock")
|
||||
targetListener, err := net.Listen("unix", targetSocketPath)
|
||||
if err != nil {
|
||||
t.Fatalf("listen target unix socket: %v", err)
|
||||
}
|
||||
defer targetListener.Close()
|
||||
|
||||
done := make(chan []byte, 1)
|
||||
go func() {
|
||||
conn, acceptErr := targetListener.Accept()
|
||||
if acceptErr != nil {
|
||||
done <- nil
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
buf := make([]byte, 64)
|
||||
n, _ := conn.Read(buf)
|
||||
_, _ = conn.Write(buf[:n])
|
||||
done <- buf[:n]
|
||||
}()
|
||||
|
||||
dialRecordCh := make(chan dialRecord, 1)
|
||||
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
|
||||
if client != baseClient {
|
||||
t.Fatalf("unexpected ssh client %p", client)
|
||||
}
|
||||
dialRecordCh <- dialRecord{network: network, addr: address}
|
||||
var dialer net.Dialer
|
||||
return dialer.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
forwarder, err := star.StartLocalTCPToUnixForward("127.0.0.1:0", targetSocketPath)
|
||||
if err != nil {
|
||||
t.Fatalf("start local tcp-to-unix forward: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
closeErr := forwarder.Close()
|
||||
if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) {
|
||||
t.Fatalf("close local tcp-to-unix forward: %v", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
reply := exerciseForwarder(t, forwarder.Addr().String(), []byte("tcp-to-unix"))
|
||||
if string(reply) != "tcp-to-unix" {
|
||||
t.Fatalf("unexpected tcp-to-unix reply: %q", string(reply))
|
||||
}
|
||||
|
||||
select {
|
||||
case got := <-done:
|
||||
if string(got) != "tcp-to-unix" {
|
||||
t.Fatalf("unexpected payload seen by unix target: %q", string(got))
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("unix target did not receive forwarded tcp payload")
|
||||
}
|
||||
|
||||
select {
|
||||
case got := <-dialRecordCh:
|
||||
if got.network != "unix" || got.addr != targetSocketPath {
|
||||
t.Fatalf("unexpected unix target dial: network=%q addr=%q", got.network, got.addr)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("did not observe unix target dial")
|
||||
}
|
||||
}
|
||||
|
||||
func echoForwardPipe(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
buf := make([]byte, 4096)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = conn.Write(buf[:n])
|
||||
}
|
||||
|
||||
func exerciseForwarder(t *testing.T, addr string, payload []byte) []byte {
|
||||
t.Helper()
|
||||
|
||||
conn, err := net.DialTimeout("tcp", addr, time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("dial forward listener: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
|
||||
|
||||
if _, err := conn.Write(payload); err != nil {
|
||||
t.Fatalf("write forwarded payload: %v", err)
|
||||
}
|
||||
|
||||
reply := make([]byte, len(payload))
|
||||
if _, err := io.ReadFull(conn, reply); err != nil {
|
||||
t.Fatalf("read forwarded reply: %v", err)
|
||||
}
|
||||
return reply
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
module b612.me/starssh
|
||||
|
||||
go 1.20
|
||||
|
||||
require (
|
||||
github.com/Microsoft/go-winio v0.6.1
|
||||
github.com/pkg/sftp v1.13.9
|
||||
golang.org/x/crypto v0.33.0
|
||||
golang.org/x/sys v0.30.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/kr/fs v0.1.0 // indirect
|
||||
golang.org/x/mod v0.17.0 // indirect
|
||||
golang.org/x/sync v0.10.0 // indirect
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
|
||||
)
|
||||
@@ -0,0 +1,92 @@
|
||||
github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow=
|
||||
github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
|
||||
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
|
||||
github.com/pkg/sftp v1.13.9 h1:4NGkvGudBL7GteO3m6qnaQ4pC0Kvf0onSVc9gR3EWBw=
|
||||
github.com/pkg/sftp v1.13.9/go.mod h1:OBN7bVXdstkFFN/gdnHPUb5TE8eb8G1Rp9wCItqjkkA=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
|
||||
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
|
||||
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
|
||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
|
||||
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
||||
golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
|
||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
||||
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
|
||||
golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg=
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
+290
@@ -0,0 +1,290 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/knownhosts"
|
||||
)
|
||||
|
||||
var ErrKnownHostsFileRequired = errors.New("known_hosts file is required")
|
||||
var ErrHostFingerprintRequired = errors.New("host key fingerprint is required")
|
||||
|
||||
type HostKeyFingerprintMismatchError struct {
|
||||
Expected []string
|
||||
ActualSHA256 string
|
||||
ActualLegacyMD5 string
|
||||
}
|
||||
|
||||
type AcceptNewHostKeyOptions struct {
|
||||
KnownHostsFile string
|
||||
HashHosts bool
|
||||
IncludeRemoteAddress bool
|
||||
FileMode os.FileMode
|
||||
}
|
||||
|
||||
func (e *HostKeyFingerprintMismatchError) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("host key fingerprint mismatch: want one of %s, got %s (%s)", strings.Join(e.Expected, ", "), e.ActualSHA256, e.ActualLegacyMD5)
|
||||
}
|
||||
|
||||
func KnownHostsHostKeyCallback(files ...string) (func(string, net.Addr, ssh.PublicKey) error, error) {
|
||||
trimmed := make([]string, 0, len(files))
|
||||
for _, file := range files {
|
||||
file = strings.TrimSpace(file)
|
||||
if file == "" {
|
||||
continue
|
||||
}
|
||||
trimmed = append(trimmed, file)
|
||||
}
|
||||
if len(trimmed) == 0 {
|
||||
return nil, ErrKnownHostsFileRequired
|
||||
}
|
||||
return knownhosts.New(trimmed...)
|
||||
}
|
||||
|
||||
func AcceptNewHostKeyCallback(file string) (func(string, net.Addr, ssh.PublicKey) error, error) {
|
||||
return AcceptNewHostKeyCallbackWithOptions(AcceptNewHostKeyOptions{
|
||||
KnownHostsFile: file,
|
||||
IncludeRemoteAddress: true,
|
||||
})
|
||||
}
|
||||
|
||||
func AcceptNewHostKeyCallbackWithOptions(options AcceptNewHostKeyOptions) (func(string, net.Addr, ssh.PublicKey) error, error) {
|
||||
options = normalizeAcceptNewHostKeyOptions(options)
|
||||
if options.KnownHostsFile == "" {
|
||||
return nil, ErrKnownHostsFileRequired
|
||||
}
|
||||
|
||||
state := &acceptNewHostKeyState{
|
||||
file: options.KnownHostsFile,
|
||||
hashHosts: options.HashHosts,
|
||||
includeRemoteAddress: options.IncludeRemoteAddress,
|
||||
fileMode: options.FileMode,
|
||||
}
|
||||
if err := state.reload(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return state.checkHostKey, nil
|
||||
}
|
||||
|
||||
func FingerprintHostKeyCallback(fingerprints ...string) (func(string, net.Addr, ssh.PublicKey) error, error) {
|
||||
normalized := make([]string, 0, len(fingerprints))
|
||||
seen := make(map[string]struct{}, len(fingerprints))
|
||||
for _, raw := range fingerprints {
|
||||
fingerprint, err := normalizeHostKeyFingerprint(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if fingerprint == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[fingerprint]; exists {
|
||||
continue
|
||||
}
|
||||
seen[fingerprint] = struct{}{}
|
||||
normalized = append(normalized, fingerprint)
|
||||
}
|
||||
if len(normalized) == 0 {
|
||||
return nil, ErrHostFingerprintRequired
|
||||
}
|
||||
|
||||
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
actualSHA256 := ssh.FingerprintSHA256(key)
|
||||
actualMD5 := normalizeMD5Fingerprint(ssh.FingerprintLegacyMD5(key))
|
||||
for _, want := range normalized {
|
||||
if want == actualSHA256 || want == actualMD5 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return &HostKeyFingerprintMismatchError{
|
||||
Expected: append([]string(nil), normalized...),
|
||||
ActualSHA256: actualSHA256,
|
||||
ActualLegacyMD5: actualMD5,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
type acceptNewHostKeyState struct {
|
||||
file string
|
||||
hashHosts bool
|
||||
includeRemoteAddress bool
|
||||
fileMode os.FileMode
|
||||
|
||||
mu sync.Mutex
|
||||
cb ssh.HostKeyCallback
|
||||
}
|
||||
|
||||
func (s *acceptNewHostKeyState) checkHostKey(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.cb != nil {
|
||||
err := s.cb(hostname, remote, key)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var keyErr *knownhosts.KeyError
|
||||
if !errors.As(err, &keyErr) || len(keyErr.Want) != 0 {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
line, err := buildAcceptNewKnownHostsLine(hostname, remote, key, s.hashHosts, s.includeRemoteAddress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := appendKnownHostsLine(s.file, line, s.fileMode); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.reload(); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.cb == nil {
|
||||
return errors.New("known_hosts callback is nil after reload")
|
||||
}
|
||||
return s.cb(hostname, remote, key)
|
||||
}
|
||||
|
||||
func (s *acceptNewHostKeyState) reload() error {
|
||||
callback, err := loadKnownHostsCallback(s.file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.cb = callback
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadKnownHostsCallback(file string) (ssh.HostKeyCallback, error) {
|
||||
_, err := os.Stat(file)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return knownhosts.New(file)
|
||||
}
|
||||
|
||||
func normalizeAcceptNewHostKeyOptions(options AcceptNewHostKeyOptions) AcceptNewHostKeyOptions {
|
||||
options.KnownHostsFile = strings.TrimSpace(options.KnownHostsFile)
|
||||
if options.FileMode == 0 {
|
||||
options.FileMode = 0o600
|
||||
}
|
||||
return options
|
||||
}
|
||||
|
||||
func buildAcceptNewKnownHostsLine(hostname string, remote net.Addr, key ssh.PublicKey, hashHosts bool, includeRemoteAddress bool) (string, error) {
|
||||
addresses := collectKnownHostsAddresses(hostname, remote, includeRemoteAddress)
|
||||
if len(addresses) == 0 {
|
||||
return "", errors.New("no hostname or remote address available for known_hosts entry")
|
||||
}
|
||||
|
||||
patterns := make([]string, 0, len(addresses))
|
||||
for _, address := range addresses {
|
||||
normalized := knownhosts.Normalize(address)
|
||||
if hashHosts {
|
||||
patterns = append(patterns, knownhosts.HashHostname(normalized))
|
||||
continue
|
||||
}
|
||||
patterns = append(patterns, normalized)
|
||||
}
|
||||
|
||||
authorizedKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(key)))
|
||||
return strings.Join(patterns, ",") + " " + authorizedKey, nil
|
||||
}
|
||||
|
||||
func collectKnownHostsAddresses(hostname string, remote net.Addr, includeRemoteAddress bool) []string {
|
||||
addresses := make([]string, 0, 2)
|
||||
seen := make(map[string]struct{}, 2)
|
||||
|
||||
add := func(address string) {
|
||||
address = strings.TrimSpace(address)
|
||||
if address == "" {
|
||||
return
|
||||
}
|
||||
normalized := knownhosts.Normalize(address)
|
||||
if _, exists := seen[normalized]; exists {
|
||||
return
|
||||
}
|
||||
seen[normalized] = struct{}{}
|
||||
addresses = append(addresses, address)
|
||||
}
|
||||
|
||||
add(hostname)
|
||||
if includeRemoteAddress && remote != nil {
|
||||
add(remote.String())
|
||||
}
|
||||
|
||||
return addresses
|
||||
}
|
||||
|
||||
func appendKnownHostsLine(file string, line string, mode os.FileMode) error {
|
||||
if strings.TrimSpace(file) == "" {
|
||||
return ErrKnownHostsFileRequired
|
||||
}
|
||||
if strings.TrimSpace(line) == "" {
|
||||
return errors.New("known_hosts line is empty")
|
||||
}
|
||||
|
||||
dir := filepath.Dir(file)
|
||||
if dir != "." && dir != "" {
|
||||
if err := os.MkdirAll(dir, 0o700); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
handle, err := os.OpenFile(file, os.O_CREATE|os.O_APPEND|os.O_WRONLY, mode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer handle.Close()
|
||||
|
||||
if _, err := handle.WriteString(line + "\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
return handle.Chmod(mode)
|
||||
}
|
||||
|
||||
func normalizeHostKeyFingerprint(raw string) (string, error) {
|
||||
value := strings.TrimSpace(raw)
|
||||
if value == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(strings.ToUpper(value), "SHA256:") {
|
||||
suffix := strings.TrimSpace(value[len("SHA256:"):])
|
||||
if suffix == "" {
|
||||
return "", ErrHostFingerprintRequired
|
||||
}
|
||||
return "SHA256:" + suffix, nil
|
||||
}
|
||||
if strings.HasPrefix(strings.ToUpper(value), "MD5:") {
|
||||
suffix := strings.TrimSpace(value[len("MD5:"):])
|
||||
if suffix == "" {
|
||||
return "", ErrHostFingerprintRequired
|
||||
}
|
||||
return normalizeMD5Fingerprint("MD5:" + suffix), nil
|
||||
}
|
||||
if strings.Count(value, ":") >= 2 {
|
||||
return normalizeMD5Fingerprint("MD5:" + value), nil
|
||||
}
|
||||
return "SHA256:" + value, nil
|
||||
}
|
||||
|
||||
func normalizeMD5Fingerprint(value string) string {
|
||||
if !strings.HasPrefix(strings.ToUpper(value), "MD5:") {
|
||||
return "MD5:" + strings.ToLower(value)
|
||||
}
|
||||
return "MD5:" + strings.ToLower(value[len("MD5:"):])
|
||||
}
|
||||
+135
@@ -0,0 +1,135 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var sendKeepAliveRequest = func(ctx context.Context, client sshClientRequester) error {
|
||||
_, _, err := client.SendRequest("keepalive@openssh.com", true, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *StarSSH) Ping() error {
|
||||
return s.PingContext(context.Background())
|
||||
}
|
||||
|
||||
func (s *StarSSH) PingContext(ctx context.Context) error {
|
||||
return s.pingContext(ctx, nil)
|
||||
}
|
||||
|
||||
func (s *StarSSH) PingContextCloseOnCancel(ctx context.Context) error {
|
||||
return s.pingContext(ctx, s.Close)
|
||||
}
|
||||
|
||||
func (s *StarSSH) pingContext(ctx context.Context, onCancel func() error) error {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
client, err := s.requireSSHClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
runCancel := func() {}
|
||||
if onCancel != nil {
|
||||
var cancelOnce sync.Once
|
||||
runCancel = func() {
|
||||
cancelOnce.Do(func() {
|
||||
_ = onCancel()
|
||||
})
|
||||
}
|
||||
|
||||
cancelDone := make(chan struct{})
|
||||
defer close(cancelDone)
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
runCancel()
|
||||
case <-cancelDone:
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
requestFunc := sendKeepAliveRequest
|
||||
pingErr := make(chan error, 1)
|
||||
go func() {
|
||||
err := requestFunc(ctx, client)
|
||||
select {
|
||||
case pingErr <- err:
|
||||
default:
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-pingErr:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
runCancel()
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StarSSH) startAutoKeepAlive() {
|
||||
if s == nil || s.snapshotSSHClient() == nil {
|
||||
return
|
||||
}
|
||||
|
||||
interval := s.LoginInfo.KeepAliveInterval
|
||||
if interval <= 0 {
|
||||
return
|
||||
}
|
||||
timeout := s.LoginInfo.KeepAliveTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = defaultKeepAliveTimeout
|
||||
}
|
||||
|
||||
stop := make(chan struct{})
|
||||
done := make(chan struct{})
|
||||
|
||||
s.keepaliveMu.Lock()
|
||||
if s.keepaliveStop != nil {
|
||||
s.keepaliveMu.Unlock()
|
||||
return
|
||||
}
|
||||
s.keepaliveStop = stop
|
||||
s.keepaliveDone = done
|
||||
s.keepaliveMu.Unlock()
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
defer close(done)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case <-ticker.C:
|
||||
pingCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
err := s.pingContext(pingCtx, nil)
|
||||
cancel()
|
||||
if err != nil {
|
||||
s.closeFromKeepAlive()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *StarSSH) stopAutoKeepAlive() {
|
||||
stop, done := s.takeKeepaliveHandles()
|
||||
if stop != nil {
|
||||
close(stop)
|
||||
}
|
||||
if done != nil {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StarSSH) closeFromKeepAlive() {
|
||||
_ = s.closeTransport(false)
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestAutoKeepAliveTimeoutDoesNotDeadlock(t *testing.T) {
|
||||
oldSendKeepAliveRequest := sendKeepAliveRequest
|
||||
oldCloseSSHClient := closeSSHClient
|
||||
t.Cleanup(func() {
|
||||
sendKeepAliveRequest = oldSendKeepAliveRequest
|
||||
closeSSHClient = oldCloseSSHClient
|
||||
})
|
||||
|
||||
sendKeepAliveRequest = func(ctx context.Context, client sshClientRequester) error {
|
||||
<-ctx.Done()
|
||||
return ctx.Err()
|
||||
}
|
||||
closeSSHClient = func(client sshClientRequester) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
client := &ssh.Client{}
|
||||
star := &StarSSH{
|
||||
LoginInfo: LoginInput{
|
||||
KeepAliveInterval: 10 * time.Millisecond,
|
||||
KeepAliveTimeout: 20 * time.Millisecond,
|
||||
},
|
||||
}
|
||||
star.setTransport(client, nil)
|
||||
star.startAutoKeepAlive()
|
||||
|
||||
star.keepaliveMu.Lock()
|
||||
done := star.keepaliveDone
|
||||
star.keepaliveMu.Unlock()
|
||||
if done == nil {
|
||||
t.Fatal("keepalive did not start")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("keepalive goroutine did not exit after keepalive timeout")
|
||||
}
|
||||
|
||||
if client := star.snapshotSSHClient(); client != nil {
|
||||
t.Fatal("ssh client should be closed after keepalive timeout")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,193 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var ErrHostKeyCallbackRequired = errors.New("host key callback is required; use DefaultAllowHostKeyCallback to explicitly allow any host key")
|
||||
|
||||
func DefaultAllowHostKeyCallback(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func LoginContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
|
||||
return loginWithContext(ctx, info)
|
||||
}
|
||||
|
||||
func Login(info LoginInput) (*StarSSH, error) {
|
||||
return LoginContext(context.Background(), info)
|
||||
}
|
||||
|
||||
func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
|
||||
info = normalizeLoginInput(info)
|
||||
if info.HostKeyCallback == nil {
|
||||
return nil, ErrHostKeyCallbackRequired
|
||||
}
|
||||
|
||||
authTimeout := effectiveLoginTimeout(info)
|
||||
loginCtx, cancel := contextWithLoginTimeout(ctx, authTimeout)
|
||||
defer cancel()
|
||||
|
||||
order, err := normalizeAuthOrder(info.AuthOrder)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if shouldRetrySSHAgentAuth(info, order) {
|
||||
agentAttempt := newSSHAgentAuthAttempt()
|
||||
for {
|
||||
agentAttempt.begin()
|
||||
sshInfo, err := loginOnceWithContext(loginCtx, info, authTimeout, agentAttempt)
|
||||
if err == nil {
|
||||
return sshInfo, nil
|
||||
}
|
||||
if errors.Is(err, errRetrySSHAgentAuth) && loginCtx.Err() == nil {
|
||||
continue
|
||||
}
|
||||
return sshInfo, err
|
||||
}
|
||||
}
|
||||
|
||||
return loginOnceWithContext(loginCtx, info, authTimeout, nil)
|
||||
}
|
||||
|
||||
func loginOnceWithContext(ctx context.Context, info LoginInput, authTimeout time.Duration, agentAttempt *sshAgentAuthAttempt) (*StarSSH, error) {
|
||||
sshInfo := &StarSSH{
|
||||
LoginInfo: info,
|
||||
}
|
||||
|
||||
auth, authCleanup, err := buildAuthMethodsWithAgentAttempt(info, agentAttempt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if authCleanup != nil {
|
||||
defer authCleanup()
|
||||
}
|
||||
|
||||
hostKeyCallback := func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
sshInfo.PublicKey = key
|
||||
sshInfo.RemoteAddr = remote
|
||||
sshInfo.Hostname = hostname
|
||||
|
||||
return info.HostKeyCallback(hostname, remote, key)
|
||||
}
|
||||
|
||||
bannerCallback := func(banner string) error {
|
||||
sshInfo.Banner = banner
|
||||
if info.BannerCallback != nil {
|
||||
return info.BannerCallback(banner)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
clientConfig := &ssh.ClientConfig{
|
||||
User: info.User,
|
||||
Auth: auth,
|
||||
Timeout: authTimeout,
|
||||
HostKeyCallback: hostKeyCallback,
|
||||
BannerCallback: bannerCallback,
|
||||
}
|
||||
if len(info.Ciphers) > 0 || len(info.MACs) > 0 || len(info.KeyExchanges) > 0 {
|
||||
clientConfig.Config = ssh.Config{
|
||||
Ciphers: info.Ciphers,
|
||||
MACs: info.MACs,
|
||||
KeyExchanges: info.KeyExchanges,
|
||||
}
|
||||
}
|
||||
|
||||
targetAddr := joinHostPort(info.Addr, info.Port)
|
||||
rawConn, upstream, err := dialTargetConn(ctx, info)
|
||||
if err != nil {
|
||||
return sshInfo, err
|
||||
}
|
||||
restoreDeadline := applyConnDeadline(rawConn, ctx, authTimeout)
|
||||
defer restoreDeadline()
|
||||
|
||||
clientConn, chans, reqs, err := ssh.NewClientConn(rawConn, targetAddr, clientConfig)
|
||||
if err != nil {
|
||||
_ = rawConn.Close()
|
||||
if upstream != nil {
|
||||
_ = upstream.Close()
|
||||
}
|
||||
return sshInfo, err
|
||||
}
|
||||
client := ssh.NewClient(clientConn, chans, reqs)
|
||||
|
||||
sshInfo.setTransport(client, upstream)
|
||||
if sshInfo.PublicKey != nil {
|
||||
sshInfo.PubkeyBase64 = base64.StdEncoding.EncodeToString(sshInfo.PublicKey.Marshal())
|
||||
}
|
||||
sshInfo.startAutoKeepAlive()
|
||||
|
||||
return sshInfo, nil
|
||||
}
|
||||
|
||||
func contextWithLoginTimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if timeout <= 0 {
|
||||
return ctx, func() {}
|
||||
}
|
||||
return context.WithTimeout(ctx, timeout)
|
||||
}
|
||||
|
||||
func LoginSimple(host string, user string, passwd string, prikeyPath string, port int, timeout time.Duration) (*StarSSH, error) {
|
||||
info := LoginInput{
|
||||
Addr: host,
|
||||
Port: port,
|
||||
Timeout: timeout,
|
||||
DialTimeout: timeout,
|
||||
User: user,
|
||||
HostKeyCallback: DefaultAllowHostKeyCallback,
|
||||
}
|
||||
|
||||
if prikeyPath != "" {
|
||||
prikey, err := os.ReadFile(prikeyPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info.Prikey = string(prikey)
|
||||
if passwd != "" {
|
||||
info.PrikeyPwd = passwd
|
||||
}
|
||||
} else {
|
||||
info.Password = passwd
|
||||
}
|
||||
|
||||
return Login(info)
|
||||
}
|
||||
|
||||
func normalizeLoginInput(info LoginInput) LoginInput {
|
||||
if info.Port <= 0 {
|
||||
info.Port = defaultSSHPort
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
func effectiveLoginTimeout(info LoginInput) time.Duration {
|
||||
if info.Timeout <= 0 {
|
||||
return 0
|
||||
}
|
||||
return info.Timeout
|
||||
}
|
||||
|
||||
func effectiveDialTimeout(info LoginInput) time.Duration {
|
||||
switch {
|
||||
case info.DialTimeout < 0:
|
||||
return 0
|
||||
case info.DialTimeout > 0:
|
||||
return info.DialTimeout
|
||||
case info.Timeout > 0:
|
||||
return info.Timeout
|
||||
default:
|
||||
return defaultLoginTimeout
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,763 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
sshagent "golang.org/x/crypto/ssh/agent"
|
||||
)
|
||||
|
||||
func TestNormalizeLoginInputKeepsZeroAuthTimeout(t *testing.T) {
|
||||
info := normalizeLoginInput(LoginInput{})
|
||||
if info.Port != defaultSSHPort {
|
||||
t.Fatalf("Port=%d want %d", info.Port, defaultSSHPort)
|
||||
}
|
||||
if info.Timeout != 0 {
|
||||
t.Fatalf("Timeout=%v want 0", info.Timeout)
|
||||
}
|
||||
if info.DialTimeout != 0 {
|
||||
t.Fatalf("DialTimeout=%v want 0", info.DialTimeout)
|
||||
}
|
||||
if info.SSHAgentTimeout != 0 {
|
||||
t.Fatalf("SSHAgentTimeout=%v want 0", info.SSHAgentTimeout)
|
||||
}
|
||||
if info.SSHAgentForwardTimeout != 0 {
|
||||
t.Fatalf("SSHAgentForwardTimeout=%v want 0", info.SSHAgentForwardTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveLoginTimeout(t *testing.T) {
|
||||
if got := effectiveLoginTimeout(LoginInput{}); got != 0 {
|
||||
t.Fatalf("zero login timeout should stay zero, got %v", got)
|
||||
}
|
||||
if got := effectiveLoginTimeout(LoginInput{Timeout: 7 * time.Second}); got != 7*time.Second {
|
||||
t.Fatalf("expected explicit login timeout, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveDialTimeout(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
info LoginInput
|
||||
want time.Duration
|
||||
}{
|
||||
{
|
||||
name: "default fallback",
|
||||
info: LoginInput{},
|
||||
want: defaultLoginTimeout,
|
||||
},
|
||||
{
|
||||
name: "reuse timeout when dial timeout omitted",
|
||||
info: LoginInput{Timeout: 9 * time.Second},
|
||||
want: 9 * time.Second,
|
||||
},
|
||||
{
|
||||
name: "explicit dial timeout wins",
|
||||
info: LoginInput{Timeout: 9 * time.Second, DialTimeout: 3 * time.Second},
|
||||
want: 3 * time.Second,
|
||||
},
|
||||
{
|
||||
name: "negative dial timeout disables default dial deadline",
|
||||
info: LoginInput{Timeout: 9 * time.Second, DialTimeout: -1},
|
||||
want: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := effectiveDialTimeout(tc.info); got != tc.want {
|
||||
t.Fatalf("effectiveDialTimeout(%+v)=%v want %v", tc.info, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveSSHAgentTimeout(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
info LoginInput
|
||||
want time.Duration
|
||||
}{
|
||||
{
|
||||
name: "default fallback without auth timeout",
|
||||
info: LoginInput{},
|
||||
want: defaultSSHAgentTimeout,
|
||||
},
|
||||
{
|
||||
name: "auth timeout does not cap default",
|
||||
info: LoginInput{Timeout: 9 * time.Second},
|
||||
want: defaultSSHAgentTimeout,
|
||||
},
|
||||
{
|
||||
name: "explicit agent timeout wins",
|
||||
info: LoginInput{Timeout: 9 * time.Second, DialTimeout: 3 * time.Second, SSHAgentTimeout: 90 * time.Second},
|
||||
want: 90 * time.Second,
|
||||
},
|
||||
{
|
||||
name: "negative agent timeout disables operation deadline",
|
||||
info: LoginInput{SSHAgentTimeout: -1},
|
||||
want: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := effectiveSSHAgentTimeout(tc.info); got != tc.want {
|
||||
t.Fatalf("effectiveSSHAgentTimeout(%+v)=%v want %v", tc.info, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveSSHAgentForwardTimeout(t *testing.T) {
|
||||
if got := effectiveSSHAgentForwardTimeout(LoginInput{}); got != 0 {
|
||||
t.Fatalf("zero forward timeout should stay zero, got %v", got)
|
||||
}
|
||||
if got := effectiveSSHAgentForwardTimeout(LoginInput{SSHAgentForwardTimeout: 4 * time.Second}); got != 4*time.Second {
|
||||
t.Fatalf("expected explicit forward timeout, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthMethodsUsesSeparateSSHAgentTimeouts(t *testing.T) {
|
||||
oldBuilder := buildSSHAgentAuthMethodFunc
|
||||
t.Cleanup(func() {
|
||||
buildSSHAgentAuthMethodFunc = oldBuilder
|
||||
})
|
||||
|
||||
captured := sshAgentTimeouts{Dial: -2, Operation: -2, Forward: -2}
|
||||
buildSSHAgentAuthMethodFunc = func(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
|
||||
captured = timeouts
|
||||
return ssh.Password("agent"), nil, nil
|
||||
}
|
||||
|
||||
info := LoginInput{
|
||||
Timeout: 0,
|
||||
DialTimeout: 11 * time.Second,
|
||||
SSHAgentTimeout: 90 * time.Second,
|
||||
SSHAgentForwardTimeout: 4 * time.Second,
|
||||
IdentityAgent: "/tmp/custom-agent.sock",
|
||||
AuthOrder: []AuthMethodKind{AuthMethodSSHAgent},
|
||||
}
|
||||
auth, cleanup, err := buildAuthMethods(info)
|
||||
if err != nil {
|
||||
t.Fatalf("buildAuthMethods: %v", err)
|
||||
}
|
||||
if cleanup != nil {
|
||||
cleanup()
|
||||
}
|
||||
if len(auth) != 1 {
|
||||
t.Fatalf("expected one auth method, got %d", len(auth))
|
||||
}
|
||||
if captured.Dial != 11*time.Second {
|
||||
t.Fatalf("agent auth builder dial timeout=%v want %v", captured.Dial, 11*time.Second)
|
||||
}
|
||||
if captured.Operation != 90*time.Second {
|
||||
t.Fatalf("agent auth builder operation timeout=%v want %v", captured.Operation, 90*time.Second)
|
||||
}
|
||||
if captured.Forward != 4*time.Second {
|
||||
t.Fatalf("agent auth builder forward timeout=%v want %v", captured.Forward, 4*time.Second)
|
||||
}
|
||||
if captured.Endpoint != "/tmp/custom-agent.sock" {
|
||||
t.Fatalf("agent auth builder endpoint=%q want custom endpoint", captured.Endpoint)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthMethodsUsesSingleAgentAuthMethod(t *testing.T) {
|
||||
oldBuilder := buildSSHAgentAuthMethodFunc
|
||||
t.Cleanup(func() {
|
||||
buildSSHAgentAuthMethodFunc = oldBuilder
|
||||
})
|
||||
|
||||
buildSSHAgentAuthMethodFunc = func(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
|
||||
return ssh.Password("agent"), nil, nil
|
||||
}
|
||||
|
||||
auth, cleanup, err := buildAuthMethods(LoginInput{
|
||||
AuthOrder: []AuthMethodKind{AuthMethodSSHAgent},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("buildAuthMethods: %v", err)
|
||||
}
|
||||
if cleanup != nil {
|
||||
cleanup()
|
||||
}
|
||||
if len(auth) != 1 {
|
||||
t.Fatalf("auth methods=%d, want 1", len(auth))
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldRetrySSHAgentAuthWhenAgentIsNotFirst(t *testing.T) {
|
||||
order := []AuthMethodKind{AuthMethodPassword, AuthMethodSSHAgent}
|
||||
if !shouldRetrySSHAgentAuth(LoginInput{}, order) {
|
||||
t.Fatal("expected ssh-agent retry when ssh-agent is present after password")
|
||||
}
|
||||
if shouldRetrySSHAgentAuth(LoginInput{DisableSSHAgent: true}, order) {
|
||||
t.Fatal("expected ssh-agent retry disabled when DisableSSHAgent is true")
|
||||
}
|
||||
if shouldRetrySSHAgentAuth(LoginInput{}, []AuthMethodKind{AuthMethodPassword}) {
|
||||
t.Fatal("expected no ssh-agent retry when ssh-agent auth is absent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthMethodsWithAgentAttemptMarksNonFirstAgentForRetry(t *testing.T) {
|
||||
oldBuilder := buildSSHAgentAuthMethodFunc
|
||||
t.Cleanup(func() {
|
||||
buildSSHAgentAuthMethodFunc = oldBuilder
|
||||
})
|
||||
|
||||
buildSSHAgentAuthMethodFunc = func(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
|
||||
if timeouts.SignFailure == nil {
|
||||
t.Fatal("expected SignFailure callback for non-first ssh-agent auth")
|
||||
}
|
||||
if timeouts.SkipFingerprints != nil {
|
||||
t.Fatalf("unexpected initial skip fingerprints: %#v", timeouts.SkipFingerprints)
|
||||
}
|
||||
return ssh.Password("agent"), nil, nil
|
||||
}
|
||||
|
||||
auth, cleanup, err := buildAuthMethodsWithAgentAttempt(LoginInput{
|
||||
Password: "secret",
|
||||
AuthOrder: []AuthMethodKind{AuthMethodPassword, AuthMethodSSHAgent},
|
||||
}, newSSHAgentAuthAttempt())
|
||||
if err != nil {
|
||||
t.Fatalf("buildAuthMethodsWithAgentAttempt: %v", err)
|
||||
}
|
||||
if cleanup != nil {
|
||||
cleanup()
|
||||
}
|
||||
if len(auth) != 2 {
|
||||
t.Fatalf("auth methods=%d want 2", len(auth))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentRetryPendingBlocksFallbackAuthThenResets(t *testing.T) {
|
||||
attempt := newSSHAgentAuthAttempt()
|
||||
attempt.skipFingerprint("SHA256:test")
|
||||
if err := checkSSHAgentRetryPending(attempt); !errors.Is(err, errRetrySSHAgentAuth) {
|
||||
t.Fatalf("retry pending err=%v want errRetrySSHAgentAuth", err)
|
||||
}
|
||||
attempt.begin()
|
||||
if err := checkSSHAgentRetryPending(attempt); err != nil {
|
||||
t.Fatalf("retry should reset on next attempt: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentRetryPendingBlocksPrivateKeyAuth(t *testing.T) {
|
||||
signer := mustGenerateTestSigner(t)
|
||||
attempt := newSSHAgentAuthAttempt()
|
||||
callback := privateKeySignersCallback(signer, attempt)
|
||||
|
||||
signers, err := callback()
|
||||
if err != nil {
|
||||
t.Fatalf("private key callback before retry: %v", err)
|
||||
}
|
||||
if len(signers) != 1 || signers[0] != signer {
|
||||
t.Fatalf("private key callback returned %#v, want original signer", signers)
|
||||
}
|
||||
|
||||
attempt.skipFingerprint("SHA256:test")
|
||||
signers, err = callback()
|
||||
if !errors.Is(err, errRetrySSHAgentAuth) {
|
||||
t.Fatalf("private key callback err=%v want errRetrySSHAgentAuth", err)
|
||||
}
|
||||
if signers != nil {
|
||||
t.Fatalf("private key callback signers=%#v want nil while retry pending", signers)
|
||||
}
|
||||
|
||||
attempt.begin()
|
||||
signers, err = callback()
|
||||
if err != nil {
|
||||
t.Fatalf("private key callback after retry reset: %v", err)
|
||||
}
|
||||
if len(signers) != 1 || signers[0] != signer {
|
||||
t.Fatalf("private key callback after retry returned %#v, want original signer", signers)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterSSHAgentSignersSkipsSignerAfterSignFailure(t *testing.T) {
|
||||
firstSigner := mustGenerateTestSigner(t)
|
||||
secondSigner := mustGenerateTestSigner(t)
|
||||
failingFirstSigner := &testFailingSigner{Signer: firstSigner, err: errors.New("first agent key cannot sign")}
|
||||
|
||||
attempt := newSSHAgentAuthAttempt()
|
||||
firstMethods := filterSSHAgentSignersForRetry([]ssh.Signer{failingFirstSigner, secondSigner}, sshAgentTimeouts{
|
||||
SignFailure: attempt.recordSignFailure,
|
||||
SkipFingerprints: attempt.skipSnapshot(),
|
||||
})
|
||||
if len(firstMethods) != 2 {
|
||||
t.Fatalf("first auth method signers=%d want 2", len(firstMethods))
|
||||
}
|
||||
if _, err := firstMethods[0].Sign(nil, []byte("challenge")); !errors.Is(err, errRetrySSHAgentAuth) {
|
||||
t.Fatalf("first signer err=%v want errRetrySSHAgentAuth", err)
|
||||
}
|
||||
secondMethods := filterSSHAgentSignersForRetry([]ssh.Signer{failingFirstSigner, secondSigner}, sshAgentTimeouts{
|
||||
SignFailure: attempt.recordSignFailure,
|
||||
SkipFingerprints: attempt.skipSnapshot(),
|
||||
})
|
||||
if len(secondMethods) != 1 {
|
||||
t.Fatalf("second auth method signers=%d want 1", len(secondMethods))
|
||||
}
|
||||
if string(secondMethods[0].PublicKey().Marshal()) != string(secondSigner.PublicKey().Marshal()) {
|
||||
t.Fatalf("second auth method did not skip failed first key")
|
||||
}
|
||||
signature, err := secondMethods[0].Sign(nil, []byte("challenge"))
|
||||
if err != nil {
|
||||
t.Fatalf("second signer Sign: %v", err)
|
||||
}
|
||||
if signature == nil {
|
||||
t.Fatal("second signer returned nil signature")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthMethodsSkipsFailedAgentSignerOnRetry(t *testing.T) {
|
||||
firstSigner := mustGenerateTestSigner(t)
|
||||
secondSigner := mustGenerateTestSigner(t)
|
||||
wantErr := errors.New("first agent key cannot sign")
|
||||
failingFirstSigner := &testFailingSigner{Signer: firstSigner, err: wantErr}
|
||||
|
||||
oldBuilder := buildSSHAgentAuthMethodFunc
|
||||
t.Cleanup(func() {
|
||||
buildSSHAgentAuthMethodFunc = oldBuilder
|
||||
})
|
||||
|
||||
var buildCalls int
|
||||
buildSSHAgentAuthMethodFunc = func(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
|
||||
buildCalls++
|
||||
filteredSigners := filterSSHAgentSignersForRetry([]ssh.Signer{failingFirstSigner, secondSigner}, timeouts)
|
||||
if buildCalls == 1 {
|
||||
if len(filteredSigners) != 2 {
|
||||
t.Fatalf("first build signers=%d want 2", len(filteredSigners))
|
||||
}
|
||||
return ssh.PublicKeys(filteredSigners...), nil, nil
|
||||
}
|
||||
if len(filteredSigners) != 1 {
|
||||
t.Fatalf("retry build signers=%d want 1", len(filteredSigners))
|
||||
}
|
||||
if string(filteredSigners[0].PublicKey().Marshal()) != string(secondSigner.PublicKey().Marshal()) {
|
||||
t.Fatal("retry build did not skip failed signer")
|
||||
}
|
||||
return ssh.PublicKeys(filteredSigners...), nil, nil
|
||||
}
|
||||
|
||||
attempt := newSSHAgentAuthAttempt()
|
||||
attempt.begin()
|
||||
auth, cleanup, err := buildAuthMethodsWithAgentAttempt(LoginInput{
|
||||
AuthOrder: []AuthMethodKind{AuthMethodSSHAgent},
|
||||
}, attempt)
|
||||
if err != nil {
|
||||
t.Fatalf("first buildAuthMethodsWithAgentAttempt: %v", err)
|
||||
}
|
||||
if cleanup != nil {
|
||||
cleanup()
|
||||
}
|
||||
if len(auth) != 1 {
|
||||
t.Fatalf("first auth methods=%d want 1", len(auth))
|
||||
}
|
||||
if _, err := failingFirstSigner.Sign(rand.Reader, []byte("challenge")); !errors.Is(err, wantErr) {
|
||||
t.Fatalf("raw failing signer err=%v", err)
|
||||
}
|
||||
firstWrapped := filterSSHAgentSignersForRetry([]ssh.Signer{failingFirstSigner}, sshAgentTimeouts{
|
||||
SignFailure: attempt.recordSignFailure,
|
||||
})[0]
|
||||
if _, err := firstWrapped.Sign(rand.Reader, []byte("challenge")); !errors.Is(err, errRetrySSHAgentAuth) {
|
||||
t.Fatalf("wrapped failing signer err=%v want errRetrySSHAgentAuth", err)
|
||||
}
|
||||
|
||||
attempt.begin()
|
||||
auth, cleanup, err = buildAuthMethodsWithAgentAttempt(LoginInput{
|
||||
AuthOrder: []AuthMethodKind{AuthMethodSSHAgent},
|
||||
}, attempt)
|
||||
if err != nil {
|
||||
t.Fatalf("retry buildAuthMethodsWithAgentAttempt: %v", err)
|
||||
}
|
||||
if cleanup != nil {
|
||||
cleanup()
|
||||
}
|
||||
if len(auth) != 1 {
|
||||
t.Fatalf("retry auth methods=%d want 1", len(auth))
|
||||
}
|
||||
if buildCalls != 2 {
|
||||
t.Fatalf("build calls=%d want 2", buildCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrderSSHAgentSignersPrefersPriorityComment(t *testing.T) {
|
||||
plainSigner := mustGenerateTestSigner(t)
|
||||
prioritySigner := mustGenerateCommentedTestSigner(t, "priority=40")
|
||||
|
||||
ordered := orderSSHAgentSigners([]ssh.Signer{plainSigner, prioritySigner})
|
||||
if len(ordered) != 2 {
|
||||
t.Fatalf("ordered signers=%d want 2", len(ordered))
|
||||
}
|
||||
if string(ordered[0].PublicKey().Marshal()) != string(prioritySigner.PublicKey().Marshal()) {
|
||||
t.Fatalf("priority signer should be first, got %s", sshAgentSignerComment(ordered[0]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrderSSHAgentSignersPrefersCardKeys(t *testing.T) {
|
||||
plainSigner := mustGenerateTestSigner(t)
|
||||
cardSigner := mustGenerateCommentedTestSigner(t, "cardno:26_865_673")
|
||||
|
||||
ordered := orderSSHAgentSigners([]ssh.Signer{plainSigner, cardSigner})
|
||||
if len(ordered) != 2 {
|
||||
t.Fatalf("ordered signers=%d want 2", len(ordered))
|
||||
}
|
||||
if string(ordered[0].PublicKey().Marshal()) != string(cardSigner.PublicKey().Marshal()) {
|
||||
t.Fatalf("card signer should be first, got %s", sshAgentSignerComment(ordered[0]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrderSSHAgentSignersKeepsStableOrderWithoutHints(t *testing.T) {
|
||||
firstSigner := mustGenerateTestSigner(t)
|
||||
secondSigner := mustGenerateTestSigner(t)
|
||||
|
||||
ordered := orderSSHAgentSigners([]ssh.Signer{firstSigner, secondSigner})
|
||||
if len(ordered) != 2 {
|
||||
t.Fatalf("ordered signers=%d want 2", len(ordered))
|
||||
}
|
||||
if string(ordered[0].PublicKey().Marshal()) != string(firstSigner.PublicKey().Marshal()) {
|
||||
t.Fatalf("first signer changed order without hints")
|
||||
}
|
||||
if string(ordered[1].PublicKey().Marshal()) != string(secondSigner.PublicKey().Marshal()) {
|
||||
t.Fatalf("second signer changed order without hints")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHAgentSignerEmitsSignDebugWithoutChangingError(t *testing.T) {
|
||||
signer := mustGenerateTestSigner(t)
|
||||
wantErr := errors.New("agent refused operation")
|
||||
var debugCalls int
|
||||
wrapped := wrapSSHAgentSigner(&testFailingSigner{Signer: signer, err: wantErr}, sshAgentSignerOptions{
|
||||
Resolved: resolvedSSHAgentEndpoint{
|
||||
Endpoint: "/tmp/debug-agent.sock",
|
||||
Source: "identity-agent",
|
||||
Network: "unix",
|
||||
},
|
||||
Debug: func(event SSHAgentDebugEvent) {
|
||||
debugCalls++
|
||||
if event.Step != "auth" || event.Phase != "sign" {
|
||||
t.Fatalf("unexpected debug event: %+v", event)
|
||||
}
|
||||
if event.Endpoint != "/tmp/debug-agent.sock" || event.Source != "identity-agent" || event.Network != "unix" {
|
||||
t.Fatalf("unexpected endpoint details: %+v", event)
|
||||
}
|
||||
if event.Status != "error" || !errors.Is(event.Err, wantErr) {
|
||||
t.Fatalf("unexpected sign status: %+v", event)
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
_, err := wrapped.Sign(rand.Reader, []byte("challenge"))
|
||||
if !errors.Is(err, wantErr) {
|
||||
t.Fatalf("Sign err=%v want original signer error", err)
|
||||
}
|
||||
if debugCalls != 1 {
|
||||
t.Fatalf("debug calls=%d want 1", debugCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHAgentRetrySignerPrefersRSASHA2(t *testing.T) {
|
||||
signer := mustGenerateRSATestSigner(t)
|
||||
spy := &testAlgorithmSpySigner{Signer: signer}
|
||||
wrapped, ok := wrapSSHAgentSignerForRetry(spy, func(ssh.PublicKey, error) {}).(ssh.AlgorithmSigner)
|
||||
if !ok {
|
||||
t.Fatal("wrapped signer does not implement AlgorithmSigner")
|
||||
}
|
||||
|
||||
signature, err := wrapped.SignWithAlgorithm(rand.Reader, []byte("challenge"), ssh.KeyAlgoRSA)
|
||||
if err != nil {
|
||||
t.Fatalf("SignWithAlgorithm: %v", err)
|
||||
}
|
||||
if spy.lastAlgorithm != ssh.KeyAlgoRSASHA256 {
|
||||
t.Fatalf("last algorithm=%q want %q", spy.lastAlgorithm, ssh.KeyAlgoRSASHA256)
|
||||
}
|
||||
if signature.Format != ssh.KeyAlgoRSASHA256 {
|
||||
t.Fatalf("signature format=%q want %q", signature.Format, ssh.KeyAlgoRSASHA256)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHAgentRetrySignerKeepsRestrictedRSA(t *testing.T) {
|
||||
signer := mustGenerateRSATestSigner(t)
|
||||
restricted, err := ssh.NewSignerWithAlgorithms(signer.(ssh.AlgorithmSigner), []string{ssh.KeyAlgoRSA})
|
||||
if err != nil {
|
||||
t.Fatalf("NewSignerWithAlgorithms: %v", err)
|
||||
}
|
||||
spy := &testMultiAlgorithmSpySigner{
|
||||
testAlgorithmSpySigner: &testAlgorithmSpySigner{Signer: restricted},
|
||||
}
|
||||
wrapped, ok := wrapSSHAgentSignerForRetry(spy, func(ssh.PublicKey, error) {}).(ssh.AlgorithmSigner)
|
||||
if !ok {
|
||||
t.Fatal("wrapped signer does not implement AlgorithmSigner")
|
||||
}
|
||||
|
||||
signature, err := wrapped.SignWithAlgorithm(rand.Reader, []byte("challenge"), ssh.KeyAlgoRSA)
|
||||
if err != nil {
|
||||
t.Fatalf("SignWithAlgorithm: %v", err)
|
||||
}
|
||||
if spy.lastAlgorithm != ssh.KeyAlgoRSA {
|
||||
t.Fatalf("last algorithm=%q want %q", spy.lastAlgorithm, ssh.KeyAlgoRSA)
|
||||
}
|
||||
if signature.Format != ssh.KeyAlgoRSA {
|
||||
t.Fatalf("signature format=%q want %q", signature.Format, ssh.KeyAlgoRSA)
|
||||
}
|
||||
}
|
||||
|
||||
type deadlineSpyConn struct {
|
||||
net.Conn
|
||||
mu sync.Mutex
|
||||
deadlines []time.Time
|
||||
readErr error
|
||||
writeErr error
|
||||
}
|
||||
|
||||
type testFailingSigner struct {
|
||||
ssh.Signer
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *testFailingSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) {
|
||||
return nil, s.err
|
||||
}
|
||||
|
||||
func (s *testFailingSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) {
|
||||
return nil, s.err
|
||||
}
|
||||
|
||||
type testAlgorithmSpySigner struct {
|
||||
ssh.Signer
|
||||
lastAlgorithm string
|
||||
}
|
||||
|
||||
func (s *testAlgorithmSpySigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) {
|
||||
s.lastAlgorithm = algorithm
|
||||
return s.Signer.(ssh.AlgorithmSigner).SignWithAlgorithm(rand, data, algorithm)
|
||||
}
|
||||
|
||||
type testMultiAlgorithmSpySigner struct {
|
||||
*testAlgorithmSpySigner
|
||||
}
|
||||
|
||||
func (s *testMultiAlgorithmSpySigner) Algorithms() []string {
|
||||
if multiAlgorithmSigner, ok := s.Signer.(ssh.MultiAlgorithmSigner); ok {
|
||||
return multiAlgorithmSigner.Algorithms()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func mustGenerateTestSigner(t *testing.T) ssh.Signer {
|
||||
t.Helper()
|
||||
_, key, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("generate test private key: %v", err)
|
||||
}
|
||||
signer, err := ssh.NewSignerFromKey(key)
|
||||
if err != nil {
|
||||
t.Fatalf("new test signer: %v", err)
|
||||
}
|
||||
return signer
|
||||
}
|
||||
|
||||
func mustGenerateCommentedTestSigner(t *testing.T, comment string) ssh.Signer {
|
||||
t.Helper()
|
||||
baseSigner := mustGenerateTestSigner(t)
|
||||
publicKey := baseSigner.PublicKey()
|
||||
return &commentedTestSigner{
|
||||
Signer: baseSigner,
|
||||
publicKey: &sshagent.Key{
|
||||
Format: publicKey.Type(),
|
||||
Blob: publicKey.Marshal(),
|
||||
Comment: comment,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type commentedTestSigner struct {
|
||||
ssh.Signer
|
||||
publicKey ssh.PublicKey
|
||||
}
|
||||
|
||||
func (s *commentedTestSigner) PublicKey() ssh.PublicKey {
|
||||
return s.publicKey
|
||||
}
|
||||
|
||||
func mustGenerateRSATestSigner(t *testing.T) ssh.Signer {
|
||||
t.Helper()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("generate rsa test private key: %v", err)
|
||||
}
|
||||
signer, err := ssh.NewSignerFromKey(key)
|
||||
if err != nil {
|
||||
t.Fatalf("new rsa test signer: %v", err)
|
||||
}
|
||||
return signer
|
||||
}
|
||||
|
||||
func (c *deadlineSpyConn) SetDeadline(deadline time.Time) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.deadlines = append(c.deadlines, deadline)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *deadlineSpyConn) deadlineCount() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return len(c.deadlines)
|
||||
}
|
||||
|
||||
func (c *deadlineSpyConn) firstDeadline() time.Time {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.deadlines[0]
|
||||
}
|
||||
|
||||
func (c *deadlineSpyConn) Read(p []byte) (int, error) {
|
||||
if c.readErr != nil {
|
||||
return 0, c.readErr
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (c *deadlineSpyConn) Write(p []byte) (int, error) {
|
||||
if c.writeErr != nil {
|
||||
return 0, c.writeErr
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func TestWrapSSHAgentConnWithDeadlineSetsReadDeadline(t *testing.T) {
|
||||
spy := &deadlineSpyConn{readErr: io.EOF}
|
||||
conn := wrapSSHAgentConnWithDeadline(spy, 2*time.Second)
|
||||
buf := make([]byte, 1)
|
||||
if _, err := conn.Read(buf); !errors.Is(err, io.EOF) {
|
||||
t.Fatalf("Read err=%v", err)
|
||||
}
|
||||
if spy.deadlineCount() != 1 {
|
||||
t.Fatalf("deadlines=%d want 1", spy.deadlineCount())
|
||||
}
|
||||
if firstDeadline := spy.firstDeadline(); time.Until(firstDeadline) <= 0 {
|
||||
t.Fatalf("deadline=%v should be in the future", firstDeadline)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapSSHAgentConnWithDeadlineSetsWriteDeadline(t *testing.T) {
|
||||
spy := &deadlineSpyConn{}
|
||||
conn := wrapSSHAgentConnWithDeadline(spy, 2*time.Second)
|
||||
if _, err := conn.Write([]byte("x")); err != nil {
|
||||
t.Fatalf("Write err=%v", err)
|
||||
}
|
||||
if spy.deadlineCount() != 1 {
|
||||
t.Fatalf("deadlines=%d want 1", spy.deadlineCount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveSSHAgentEndpointUsesIdentityAgent(t *testing.T) {
|
||||
t.Setenv("SSH_AUTH_SOCK", "/tmp/env-agent.sock")
|
||||
resolved, err := resolveSSHAgentEndpoint(sshAgentDialOptions{Endpoint: " /tmp/identity-agent.sock "})
|
||||
if err != nil {
|
||||
t.Fatalf("resolveSSHAgentEndpoint: %v", err)
|
||||
}
|
||||
if resolved.Endpoint != "/tmp/identity-agent.sock" {
|
||||
t.Fatalf("endpoint=%q", resolved.Endpoint)
|
||||
}
|
||||
if resolved.Source != "identity-agent" {
|
||||
t.Fatalf("source=%q", resolved.Source)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveSSHAgentEndpointUsesSSHAuthSock(t *testing.T) {
|
||||
t.Setenv("SSH_AUTH_SOCK", "/tmp/env-agent.sock")
|
||||
resolved, err := resolveSSHAgentEndpoint(sshAgentDialOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("resolveSSHAgentEndpoint: %v", err)
|
||||
}
|
||||
if resolved.Endpoint != "/tmp/env-agent.sock" {
|
||||
t.Fatalf("endpoint=%q", resolved.Endpoint)
|
||||
}
|
||||
if resolved.Source != "SSH_AUTH_SOCK" {
|
||||
t.Fatalf("source=%q", resolved.Source)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSSHAgentAuthMethodTimesOutWhenAgentDoesNotRespond(t *testing.T) {
|
||||
server, client := net.Pipe()
|
||||
defer server.Close()
|
||||
|
||||
oldDialResolvedSSHAgent := dialResolvedSSHAgentFunc
|
||||
t.Cleanup(func() {
|
||||
dialResolvedSSHAgentFunc = oldDialResolvedSSHAgent
|
||||
})
|
||||
dialResolvedSSHAgentFunc = func(resolved resolvedSSHAgentEndpoint, timeout time.Duration) (net.Conn, error) {
|
||||
return client, nil
|
||||
}
|
||||
|
||||
_, cleanup, err := buildSSHAgentAuthMethod(sshAgentTimeouts{
|
||||
Operation: 20 * time.Millisecond,
|
||||
Endpoint: "/tmp/hung-agent.sock",
|
||||
})
|
||||
if cleanup != nil {
|
||||
cleanup()
|
||||
}
|
||||
if !errors.Is(err, ErrSSHAgentTimeout) {
|
||||
t.Fatalf("err=%v want ErrSSHAgentTimeout", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSSHAgentAuthMethodEmitsDebugEvents(t *testing.T) {
|
||||
socketPath := tempUnixSocketPath(t)
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
if err != nil {
|
||||
t.Fatalf("listen unix: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
var events []SSHAgentDebugEvent
|
||||
_, _, _ = buildSSHAgentAuthMethod(sshAgentTimeouts{
|
||||
Dial: time.Second,
|
||||
Operation: time.Second,
|
||||
Endpoint: socketPath,
|
||||
Debug: func(event SSHAgentDebugEvent) {
|
||||
events = append(events, event)
|
||||
},
|
||||
})
|
||||
<-done
|
||||
|
||||
if len(events) == 0 {
|
||||
t.Fatal("expected debug events")
|
||||
}
|
||||
if events[0].Step != "auth" || events[0].Phase != "dial" {
|
||||
t.Fatalf("unexpected first event: %+v", events[0])
|
||||
}
|
||||
if events[0].Endpoint != socketPath || events[0].Source != "identity-agent" {
|
||||
t.Fatalf("unexpected endpoint event: %+v", events[0])
|
||||
}
|
||||
}
|
||||
|
||||
func tempUnixSocketPath(t *testing.T) string {
|
||||
t.Helper()
|
||||
path := t.TempDir() + "/agent.sock"
|
||||
t.Cleanup(func() {
|
||||
_ = os.Remove(path)
|
||||
})
|
||||
return path
|
||||
}
|
||||
@@ -0,0 +1,466 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const defaultExecPoolMaxOpenConns = 4
|
||||
|
||||
var ErrExecPoolClosed = errors.New("exec pool is closed")
|
||||
|
||||
type ExecPoolConfig struct {
|
||||
Login LoginInput
|
||||
MaxOpenConns int
|
||||
MaxIdleConns int
|
||||
MaxIdleTime time.Duration
|
||||
DisableHealthCheck bool
|
||||
HealthCheckTimeout time.Duration
|
||||
}
|
||||
|
||||
type ExecPoolStats struct {
|
||||
MaxOpenConns int
|
||||
MaxIdleConns int
|
||||
MaxIdleTime time.Duration
|
||||
OpenConns int
|
||||
IdleConns int
|
||||
InUseConns int
|
||||
}
|
||||
|
||||
type ExecPool struct {
|
||||
loginInfo LoginInput
|
||||
maxOpen int
|
||||
maxIdle int
|
||||
maxIdleTime time.Duration
|
||||
|
||||
idle chan *pooledClient
|
||||
done chan struct{}
|
||||
closeOnce sync.Once
|
||||
healthCheckOnAcquire bool
|
||||
healthCheckTimeout time.Duration
|
||||
|
||||
mu sync.Mutex
|
||||
open int
|
||||
closed bool
|
||||
}
|
||||
|
||||
type pooledClient struct {
|
||||
client *StarSSH
|
||||
idleAt time.Time
|
||||
}
|
||||
|
||||
func NewExecPool(config ExecPoolConfig) *ExecPool {
|
||||
maxOpen := config.MaxOpenConns
|
||||
if maxOpen <= 0 {
|
||||
maxOpen = defaultExecPoolMaxOpenConns
|
||||
}
|
||||
|
||||
maxIdle := config.MaxIdleConns
|
||||
if maxIdle <= 0 || maxIdle > maxOpen {
|
||||
maxIdle = maxOpen
|
||||
}
|
||||
|
||||
return &ExecPool{
|
||||
loginInfo: config.Login,
|
||||
maxOpen: maxOpen,
|
||||
maxIdle: maxIdle,
|
||||
maxIdleTime: normalizeMaxIdleTime(config.MaxIdleTime),
|
||||
idle: make(chan *pooledClient, maxIdle),
|
||||
done: make(chan struct{}),
|
||||
healthCheckOnAcquire: !config.DisableHealthCheck,
|
||||
healthCheckTimeout: normalizeHealthCheckTimeout(config.HealthCheckTimeout),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ExecPool) Exec(ctx context.Context, req ExecRequest) (*ExecResult, error) {
|
||||
client, err := p.Acquire(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result, execErr := client.Exec(ctx, req)
|
||||
if execErr != nil {
|
||||
p.Discard(client)
|
||||
return result, execErr
|
||||
}
|
||||
if releaseErr := p.Release(client); releaseErr != nil {
|
||||
return result, releaseErr
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (p *ExecPool) ExecString(ctx context.Context, command string) (*ExecResult, error) {
|
||||
return p.Exec(ctx, ExecRequest{
|
||||
Command: command,
|
||||
})
|
||||
}
|
||||
|
||||
func (p *ExecPool) ExecStream(ctx context.Context, req ExecRequest, onChunk func(ExecStreamChunk)) (*ExecResult, error) {
|
||||
client, err := p.Acquire(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result, execErr := client.ExecStream(ctx, req, onChunk)
|
||||
if execErr != nil {
|
||||
p.Discard(client)
|
||||
return result, execErr
|
||||
}
|
||||
if releaseErr := p.Release(client); releaseErr != nil {
|
||||
return result, releaseErr
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (p *ExecPool) WarmUp(ctx context.Context, targetIdle int) error {
|
||||
if p == nil {
|
||||
return errors.New("exec pool is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
targetIdle = p.normalizeWarmUpTarget(targetIdle)
|
||||
if targetIdle == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
idleCount, create, err := p.tryWarmUp(targetIdle)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if idleCount >= targetIdle || !create {
|
||||
return nil
|
||||
}
|
||||
|
||||
conn, err := LoginContext(ctx, p.loginInfo)
|
||||
if err != nil {
|
||||
p.releaseSlot()
|
||||
return err
|
||||
}
|
||||
if err := p.Release(conn); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ExecPool) Acquire(ctx context.Context) (*StarSSH, error) {
|
||||
if p == nil {
|
||||
return nil, errors.New("exec pool is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
for {
|
||||
idleClient, create, err := p.tryAcquire()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if idleClient != nil {
|
||||
client, ok := p.takeIdleClient(ctx, idleClient)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
if create {
|
||||
conn, err := LoginContext(ctx, p.loginInfo)
|
||||
if err != nil {
|
||||
p.releaseSlot()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-p.done:
|
||||
return nil, ErrExecPoolClosed
|
||||
case idleClient = <-p.idle:
|
||||
if idleClient == nil {
|
||||
continue
|
||||
}
|
||||
client, ok := p.takeIdleClient(ctx, idleClient)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ExecPool) Release(client *StarSSH) error {
|
||||
if p == nil {
|
||||
return errors.New("exec pool is nil")
|
||||
}
|
||||
if client == nil {
|
||||
p.releaseSlot()
|
||||
return nil
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
if p.closed {
|
||||
p.mu.Unlock()
|
||||
p.closeClient(client)
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case p.idle <- &pooledClient{
|
||||
client: client,
|
||||
idleAt: time.Now(),
|
||||
}:
|
||||
p.mu.Unlock()
|
||||
return nil
|
||||
default:
|
||||
p.mu.Unlock()
|
||||
p.closeClient(client)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ExecPool) Discard(client *StarSSH) {
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
if client == nil {
|
||||
p.releaseSlot()
|
||||
return
|
||||
}
|
||||
p.closeClient(client)
|
||||
}
|
||||
|
||||
func (p *ExecPool) Stats() ExecPoolStats {
|
||||
if p == nil {
|
||||
return ExecPoolStats{}
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
idleCount := len(p.idle)
|
||||
openCount := p.open
|
||||
inUseCount := openCount - idleCount
|
||||
if inUseCount < 0 {
|
||||
inUseCount = 0
|
||||
}
|
||||
|
||||
return ExecPoolStats{
|
||||
MaxOpenConns: p.maxOpen,
|
||||
MaxIdleConns: p.maxIdle,
|
||||
MaxIdleTime: p.maxIdleTime,
|
||||
OpenConns: openCount,
|
||||
IdleConns: idleCount,
|
||||
InUseConns: inUseCount,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ExecPool) Close() error {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var closeErr error
|
||||
p.closeOnce.Do(func() {
|
||||
p.mu.Lock()
|
||||
p.closed = true
|
||||
idleClients := p.drainIdleLocked()
|
||||
p.mu.Unlock()
|
||||
close(p.done)
|
||||
|
||||
for _, client := range idleClients {
|
||||
if err := client.Close(); err != nil && closeErr == nil {
|
||||
closeErr = err
|
||||
}
|
||||
}
|
||||
})
|
||||
return closeErr
|
||||
}
|
||||
|
||||
func (p *ExecPool) CloseIdle() error {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
idleClients := p.drainIdleLocked()
|
||||
p.mu.Unlock()
|
||||
|
||||
var closeErr error
|
||||
for _, client := range idleClients {
|
||||
if err := client.Close(); err != nil && closeErr == nil {
|
||||
closeErr = err
|
||||
}
|
||||
}
|
||||
return closeErr
|
||||
}
|
||||
|
||||
func (p *ExecPool) tryAcquire() (*pooledClient, bool, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.closed {
|
||||
return nil, false, ErrExecPoolClosed
|
||||
}
|
||||
|
||||
select {
|
||||
case client := <-p.idle:
|
||||
return client, false, nil
|
||||
default:
|
||||
}
|
||||
|
||||
if p.open < p.maxOpen {
|
||||
p.open++
|
||||
return nil, true, nil
|
||||
}
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func (p *ExecPool) tryWarmUp(targetIdle int) (int, bool, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.closed {
|
||||
return 0, false, ErrExecPoolClosed
|
||||
}
|
||||
|
||||
idleCount := len(p.idle)
|
||||
if idleCount >= targetIdle {
|
||||
return idleCount, false, nil
|
||||
}
|
||||
if p.open >= p.maxOpen {
|
||||
return idleCount, false, nil
|
||||
}
|
||||
|
||||
p.open++
|
||||
return idleCount, true, nil
|
||||
}
|
||||
|
||||
func (p *ExecPool) normalizeWarmUpTarget(targetIdle int) int {
|
||||
if p == nil || p.maxIdle <= 0 {
|
||||
return 0
|
||||
}
|
||||
if targetIdle <= 0 {
|
||||
return p.maxIdle
|
||||
}
|
||||
if targetIdle > p.maxIdle {
|
||||
return p.maxIdle
|
||||
}
|
||||
return targetIdle
|
||||
}
|
||||
|
||||
func (p *ExecPool) takeIdleClient(ctx context.Context, idleClient *pooledClient) (*StarSSH, bool) {
|
||||
if idleClient == nil {
|
||||
return nil, false
|
||||
}
|
||||
if idleClient.client == nil {
|
||||
p.releaseSlot()
|
||||
return nil, false
|
||||
}
|
||||
if p.isIdleExpired(idleClient) {
|
||||
p.closePooledClient(idleClient)
|
||||
return nil, false
|
||||
}
|
||||
if err := p.healthCheckClient(ctx, idleClient.client); err != nil {
|
||||
p.closePooledClient(idleClient)
|
||||
return nil, false
|
||||
}
|
||||
return idleClient.client, true
|
||||
}
|
||||
|
||||
func (p *ExecPool) isIdleExpired(client *pooledClient) bool {
|
||||
if p == nil || client == nil || client.client == nil {
|
||||
return false
|
||||
}
|
||||
if p.maxIdleTime <= 0 || client.idleAt.IsZero() {
|
||||
return false
|
||||
}
|
||||
return time.Since(client.idleAt) >= p.maxIdleTime
|
||||
}
|
||||
|
||||
func (p *ExecPool) drainIdleLocked() []*StarSSH {
|
||||
clients := make([]*StarSSH, 0, len(p.idle))
|
||||
for {
|
||||
select {
|
||||
case idleClient := <-p.idle:
|
||||
if p.open > 0 {
|
||||
p.open--
|
||||
}
|
||||
if idleClient == nil || idleClient.client == nil {
|
||||
continue
|
||||
}
|
||||
clients = append(clients, idleClient.client)
|
||||
default:
|
||||
return clients
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ExecPool) releaseSlot() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.open > 0 {
|
||||
p.open--
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ExecPool) closeClient(client *StarSSH) {
|
||||
if client != nil {
|
||||
_ = client.Close()
|
||||
}
|
||||
p.releaseSlot()
|
||||
}
|
||||
|
||||
func (p *ExecPool) closePooledClient(client *pooledClient) {
|
||||
if client == nil {
|
||||
return
|
||||
}
|
||||
if client.client == nil {
|
||||
p.releaseSlot()
|
||||
return
|
||||
}
|
||||
p.closeClient(client.client)
|
||||
}
|
||||
|
||||
func (p *ExecPool) healthCheckClient(ctx context.Context, client *StarSSH) error {
|
||||
if client == nil {
|
||||
return errors.New("ssh client is nil")
|
||||
}
|
||||
if !p.healthCheckOnAcquire {
|
||||
return nil
|
||||
}
|
||||
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
timeout := p.healthCheckTimeout
|
||||
if timeout > 0 {
|
||||
healthCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
ctx = healthCtx
|
||||
}
|
||||
return client.PingContext(ctx)
|
||||
}
|
||||
|
||||
func normalizeHealthCheckTimeout(timeout time.Duration) time.Duration {
|
||||
if timeout <= 0 {
|
||||
return defaultKeepAliveTimeout
|
||||
}
|
||||
return timeout
|
||||
}
|
||||
|
||||
func normalizeMaxIdleTime(timeout time.Duration) time.Duration {
|
||||
if timeout <= 0 {
|
||||
return 0
|
||||
}
|
||||
return timeout
|
||||
}
|
||||
+145
@@ -0,0 +1,145 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||
return client.NewSession()
|
||||
}
|
||||
|
||||
var requestSessionPTY = func(session *ssh.Session, config TerminalConfig) error {
|
||||
return session.RequestPty(config.Term, config.Rows, config.Columns, config.Modes)
|
||||
}
|
||||
|
||||
func (s *StarSSH) Close() error {
|
||||
return s.closeTransport(true)
|
||||
}
|
||||
|
||||
func (s *StarSSH) NewSession() (*ssh.Session, error) {
|
||||
return s.NewPTYSession(nil)
|
||||
}
|
||||
|
||||
func (s *StarSSH) NewExecSession() (*ssh.Session, error) {
|
||||
client, err := s.requireSSHClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
session, err := NewExecSession(client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.maybeRequestAgentForwarding(session); err != nil {
|
||||
_ = session.Close()
|
||||
return nil, err
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (s *StarSSH) NewPTYSession(config *TerminalConfig) (*ssh.Session, error) {
|
||||
client, err := s.requireSSHClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
session, err := NewPTYSession(client, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.maybeRequestAgentForwarding(session); err != nil {
|
||||
_ = session.Close()
|
||||
return nil, err
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func NewTransferSession(client *ssh.Client) (*ssh.Session, error) {
|
||||
return NewExecSession(client)
|
||||
}
|
||||
|
||||
func NewExecSession(client *ssh.Client) (*ssh.Session, error) {
|
||||
if client == nil {
|
||||
return nil, errors.New("ssh client is nil")
|
||||
}
|
||||
return newSSHSession(client)
|
||||
}
|
||||
|
||||
func NewSession(client *ssh.Client) (*ssh.Session, error) {
|
||||
return NewPTYSession(client, nil)
|
||||
}
|
||||
|
||||
func NewPTYSession(client *ssh.Client, config *TerminalConfig) (*ssh.Session, error) {
|
||||
if client == nil {
|
||||
return nil, errors.New("ssh client is nil")
|
||||
}
|
||||
|
||||
session, err := newSSHSession(client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg := normalizeTerminalConfig(config)
|
||||
if err := requestSessionPTY(session, cfg); err != nil {
|
||||
_ = session.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func normalizeTerminalConfig(config *TerminalConfig) TerminalConfig {
|
||||
cfg := TerminalConfig{
|
||||
Term: defaultPTYTerm,
|
||||
Rows: defaultPTYRows,
|
||||
Columns: defaultPTYColumns,
|
||||
Modes: ssh.TerminalModes{
|
||||
ssh.ECHO: 1,
|
||||
ssh.TTY_OP_ISPEED: 14400,
|
||||
ssh.TTY_OP_OSPEED: 14400,
|
||||
},
|
||||
}
|
||||
if config == nil {
|
||||
return cfg
|
||||
}
|
||||
|
||||
if strings.TrimSpace(config.Term) != "" {
|
||||
cfg.Term = config.Term
|
||||
}
|
||||
if config.Rows > 0 {
|
||||
cfg.Rows = config.Rows
|
||||
}
|
||||
if config.Columns > 0 {
|
||||
cfg.Columns = config.Columns
|
||||
}
|
||||
if len(config.Modes) > 0 {
|
||||
cfg.Modes = config.Modes
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func normalizeAlreadyClosedError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
if strings.Contains(err.Error(), "use of closed network connection") {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func closeUpstream(upstream *StarSSH) error {
|
||||
if upstream == nil {
|
||||
return nil
|
||||
}
|
||||
return upstream.Close()
|
||||
}
|
||||
+1127
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,592 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var errStarShellPOSIXOnly = errors.New("legacy StarShell only supports POSIX-compatible shells")
|
||||
|
||||
type ShellRequest struct {
|
||||
Command string
|
||||
Timeout time.Duration
|
||||
Keyword string
|
||||
UseWaitDefault *bool
|
||||
}
|
||||
|
||||
// NewShell creates the legacy prompt-driven POSIX shell helper.
|
||||
// For raw interactive terminal flows, prefer NewTerminal.
|
||||
func (s *StarSSH) NewShell() (shell *StarShell, err error) {
|
||||
shell = &StarShell{
|
||||
UseWaitDefault: true,
|
||||
WaitTimeout: defaultShellWaitTimeout,
|
||||
isecho: true,
|
||||
iscolor: true,
|
||||
promptToken: defaultShellPromptToken,
|
||||
}
|
||||
|
||||
shell.Session, err = s.NewPTYSession(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
shell.in, err = shell.Session.StdinPipe()
|
||||
if err != nil {
|
||||
_ = shell.Session.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stdout, err := shell.Session.StdoutPipe()
|
||||
if err != nil {
|
||||
_ = shell.Session.Close()
|
||||
return nil, err
|
||||
}
|
||||
shell.out = bufio.NewReader(stdout)
|
||||
|
||||
stderr, err := shell.Session.StderrPipe()
|
||||
if err != nil {
|
||||
_ = shell.Session.Close()
|
||||
return nil, err
|
||||
}
|
||||
shell.er = bufio.NewReader(stderr)
|
||||
|
||||
if err := shell.Session.Shell(); err != nil {
|
||||
_ = shell.Session.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go shell.watchSession()
|
||||
shell.gohub()
|
||||
|
||||
if err := shell.configurePrompt(context.Background()); err != nil {
|
||||
_ = shell.Session.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
shell.Clear()
|
||||
return shell, nil
|
||||
}
|
||||
|
||||
func (s *StarShell) configurePrompt(ctx context.Context) error {
|
||||
if s == nil {
|
||||
return errors.New("shell is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, defaultShellSetupTimeout)
|
||||
defer cancel()
|
||||
ctx = timeoutCtx
|
||||
}
|
||||
|
||||
prompt := s.promptToken + " "
|
||||
setupCommands := []string{
|
||||
"unset PROMPT_COMMAND >/dev/null 2>&1 || true",
|
||||
fmt.Sprintf("export PS1=%s PS2='' PROMPT=%s RPROMPT='' >/dev/null 2>&1 || true", shellSingleQuote(prompt), shellSingleQuote(prompt)),
|
||||
}
|
||||
|
||||
s.Clear()
|
||||
for _, cmd := range setupCommands {
|
||||
if err := s.WriteCommand(cmd); err != nil {
|
||||
return fmt.Errorf("%w: %v", errStarShellPOSIXOnly, err)
|
||||
}
|
||||
}
|
||||
|
||||
probeToken := "__STARSSH_POSIX_READY__" + newNonce(6)
|
||||
if err := s.WriteCommand(fmt.Sprintf("printf '%%s\\n' %s", shellSingleQuote(probeToken))); err != nil {
|
||||
return fmt.Errorf("%w: %v", errStarShellPOSIXOnly, err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(defaultShellPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
outRaw, errRaw, runErr := s.readState()
|
||||
if runErr != nil {
|
||||
return fmt.Errorf("%w: %v", errStarShellPOSIXOnly, runErr)
|
||||
}
|
||||
|
||||
outs := normalizeShellOutput(stripControlSequences(string(outRaw)))
|
||||
if strings.Contains(outs, probeToken) {
|
||||
s.Clear()
|
||||
return nil
|
||||
}
|
||||
|
||||
errs := normalizeShellOutput(stripControlSequences(string(errRaw)))
|
||||
if looksLikeNonPOSIXShellError(errs) {
|
||||
return fmt.Errorf("%w: %s", errStarShellPOSIXOnly, errs)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
|
||||
return fmt.Errorf("%w: prompt bootstrap timed out", errStarShellPOSIXOnly)
|
||||
}
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StarShell) watchSession() {
|
||||
if s == nil || s.Session == nil {
|
||||
return
|
||||
}
|
||||
if err := s.Session.Wait(); err != nil && !errors.Is(err, io.EOF) {
|
||||
s.setError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StarShell) Close() error {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var closeErr error
|
||||
s.closeOnce.Do(func() {
|
||||
if s.Session == nil {
|
||||
return
|
||||
}
|
||||
closeErr = s.Session.Close()
|
||||
})
|
||||
return closeErr
|
||||
}
|
||||
|
||||
func (s *StarShell) SwitchNoColor(run bool) {
|
||||
s.rw.Lock()
|
||||
defer s.rw.Unlock()
|
||||
s.iscolor = run
|
||||
}
|
||||
|
||||
func (s *StarShell) SwitchEcho(run bool) {
|
||||
s.rw.Lock()
|
||||
defer s.rw.Unlock()
|
||||
s.isecho = run
|
||||
}
|
||||
|
||||
func (s *StarShell) TrimColor(str string) string {
|
||||
s.rw.RLock()
|
||||
shouldTrim := s.iscolor
|
||||
s.rw.RUnlock()
|
||||
|
||||
if shouldTrim {
|
||||
return SedColor(str)
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
/*
|
||||
本函数控制是否在本地屏幕上打印远程Shell的输出内容[true|false]
|
||||
*/
|
||||
func (s *StarShell) SwitchPrint(run bool) {
|
||||
s.rw.Lock()
|
||||
defer s.rw.Unlock()
|
||||
s.isprint = run
|
||||
}
|
||||
|
||||
/*
|
||||
本函数控制是否立即处理远程Shell输出每一行内容[true|false]
|
||||
*/
|
||||
func (s *StarShell) SwitchFunc(run bool) {
|
||||
s.rw.Lock()
|
||||
defer s.rw.Unlock()
|
||||
s.isfuncs = run
|
||||
}
|
||||
|
||||
func (s *StarShell) SetFunc(funcs func(string)) {
|
||||
s.rw.Lock()
|
||||
defer s.rw.Unlock()
|
||||
s.funcs = funcs
|
||||
}
|
||||
|
||||
func (s *StarShell) Clear() {
|
||||
s.rw.Lock()
|
||||
defer s.rw.Unlock()
|
||||
s.outbyte = []byte{}
|
||||
s.errbyte = []byte{}
|
||||
}
|
||||
|
||||
func (s *StarShell) ShellClear(cmd string, sleep int) (string, string, error) {
|
||||
s.Clear()
|
||||
defer s.Clear()
|
||||
return s.Shell(cmd, sleep)
|
||||
}
|
||||
|
||||
func (s *StarShell) Shell(cmd string, sleep int) (string, string, error) {
|
||||
s.commandMu.Lock()
|
||||
defer s.commandMu.Unlock()
|
||||
|
||||
if err := s.WriteCommand(cmd); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
outRaw, errRaw, runErr := s.GetResult(sleep)
|
||||
if runErr != nil {
|
||||
return "", "", runErr
|
||||
}
|
||||
|
||||
outText := s.TrimColor(strings.TrimSpace(string(outRaw)))
|
||||
|
||||
s.rw.RLock()
|
||||
echoEnabled := s.isecho
|
||||
s.rw.RUnlock()
|
||||
if echoEnabled {
|
||||
outText = stripCommandEchoFromOutput(outText, cmd)
|
||||
}
|
||||
|
||||
return strings.TrimSpace(outText), s.TrimColor(strings.TrimSpace(string(errRaw))), nil
|
||||
}
|
||||
|
||||
func (s *StarShell) ShellWait(cmd string) (string, string, error) {
|
||||
result, err := s.Run(context.Background(), ShellRequest{
|
||||
Command: cmd,
|
||||
})
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return strings.TrimSpace(result.StdoutString()), strings.TrimSpace(result.StderrString()), result.CommandError()
|
||||
}
|
||||
|
||||
func (s *StarShell) RunString(ctx context.Context, command string) (*ExecResult, error) {
|
||||
return s.Run(ctx, ShellRequest{
|
||||
Command: command,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *StarShell) Run(ctx context.Context, req ShellRequest) (*ExecResult, error) {
|
||||
if s == nil {
|
||||
return nil, errors.New("shell is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
s.commandMu.Lock()
|
||||
defer s.commandMu.Unlock()
|
||||
|
||||
if strings.TrimSpace(req.Command) == "" {
|
||||
return nil, errors.New("command is empty")
|
||||
}
|
||||
|
||||
s.rw.RLock()
|
||||
useDefault := s.UseWaitDefault
|
||||
keyword := s.Keyword
|
||||
waitTimeout := s.WaitTimeout
|
||||
promptToken := s.promptToken
|
||||
s.rw.RUnlock()
|
||||
if req.UseWaitDefault != nil {
|
||||
useDefault = *req.UseWaitDefault
|
||||
}
|
||||
if strings.TrimSpace(req.Keyword) != "" {
|
||||
keyword = req.Keyword
|
||||
}
|
||||
if req.Timeout > 0 {
|
||||
waitTimeout = req.Timeout
|
||||
}
|
||||
if !useDefault && keyword == "" {
|
||||
return nil, errors.New("ShellRun requires UseWaitDefault=true or Keyword set")
|
||||
}
|
||||
if waitTimeout <= 0 {
|
||||
waitTimeout = defaultShellWaitTimeout
|
||||
}
|
||||
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, waitTimeout)
|
||||
defer cancel()
|
||||
ctx = timeoutCtx
|
||||
}
|
||||
startAt := time.Now()
|
||||
|
||||
s.Clear()
|
||||
defer s.Clear()
|
||||
|
||||
beginToken, endToken := newCommandTokens()
|
||||
markerCmd := fmt.Sprintf("__STARSSH_RC=$?; printf '%s:%%s\\n' \"$__STARSSH_RC\"", endToken)
|
||||
|
||||
if err := s.WriteCommand(fmt.Sprintf("printf '%s\\n'", beginToken)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.WriteCommand(req.Command); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.WriteCommand(markerCmd); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var (
|
||||
outc string
|
||||
errc string
|
||||
exitCode int
|
||||
done bool
|
||||
)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(defaultShellPollInterval):
|
||||
}
|
||||
|
||||
outRaw, errRaw, runErr := s.readState()
|
||||
if runErr != nil {
|
||||
return nil, runErr
|
||||
}
|
||||
|
||||
outs := normalizeShellOutput(stripControlSequences(string(outRaw)))
|
||||
errs := normalizeShellOutput(stripControlSequences(string(errRaw)))
|
||||
|
||||
s.rw.RLock()
|
||||
useDefault = s.UseWaitDefault
|
||||
keyword = s.Keyword
|
||||
s.rw.RUnlock()
|
||||
|
||||
if useDefault {
|
||||
segment, rc, found, parseErr := extractCommandSegment(outs, beginToken, endToken)
|
||||
if parseErr != nil {
|
||||
return nil, parseErr
|
||||
}
|
||||
if found {
|
||||
outc = segment
|
||||
errc = errs
|
||||
exitCode = rc
|
||||
done = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if keyword != "" {
|
||||
if strings.Contains(outs, keyword) || strings.Contains(errs, keyword) {
|
||||
outc = outs
|
||||
errc = errs
|
||||
done = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !done {
|
||||
return nil, errors.New("failed to collect shell result")
|
||||
}
|
||||
|
||||
outc = collectLinesForCommandOutput(outc, promptToken, beginToken, endToken)
|
||||
errc = collectLinesForCommandOutput(errc, promptToken, beginToken, endToken)
|
||||
outc = stripCommandEchoFromOutput(outc, req.Command)
|
||||
|
||||
stdoutText := strings.TrimSpace(outc)
|
||||
stderrText := strings.TrimSpace(errc)
|
||||
result := &ExecResult{
|
||||
Command: req.Command,
|
||||
Stdout: []byte(stdoutText),
|
||||
Stderr: []byte(stderrText),
|
||||
Combined: combineCommandOutput(stdoutText, stderrText),
|
||||
Duration: time.Since(startAt),
|
||||
}
|
||||
if useDefault {
|
||||
result.ExitCode = exitCode
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func extractCommandSegment(stdout string, beginToken string, endToken string) (string, int, bool, error) {
|
||||
lines := strings.Split(stdout, "\n")
|
||||
beginLine := -1
|
||||
for i, line := range lines {
|
||||
if strings.TrimSpace(line) == beginToken {
|
||||
beginLine = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if beginLine < 0 {
|
||||
return "", 0, false, nil
|
||||
}
|
||||
|
||||
segment := strings.Join(lines[beginLine+1:], "\n")
|
||||
before, rc, found, err := splitByEndToken(segment, endToken)
|
||||
if err != nil {
|
||||
return "", 0, false, err
|
||||
}
|
||||
if !found {
|
||||
return "", 0, false, nil
|
||||
}
|
||||
return strings.TrimSpace(before), rc, true, nil
|
||||
}
|
||||
|
||||
func (s *StarShell) GetResult(sleep int) ([]byte, []byte, error) {
|
||||
if sleep > 0 {
|
||||
time.Sleep(time.Millisecond * time.Duration(sleep))
|
||||
}
|
||||
return s.readState()
|
||||
}
|
||||
|
||||
func (s *StarShell) WriteCommand(cmd string) error {
|
||||
return s.Write([]byte(cmd + "\n"))
|
||||
}
|
||||
|
||||
func (s *StarShell) Write(bstr []byte) error {
|
||||
if s == nil {
|
||||
return errors.New("shell is nil")
|
||||
}
|
||||
if s.in == nil {
|
||||
return errors.New("shell stdin is not initialized")
|
||||
}
|
||||
if _, _, runErr := s.readState(); runErr != nil {
|
||||
return runErr
|
||||
}
|
||||
|
||||
s.writeMu.Lock()
|
||||
defer s.writeMu.Unlock()
|
||||
|
||||
_, err := s.in.Write(bstr)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *StarShell) gohub() {
|
||||
if s.er != nil {
|
||||
go s.streamPump(s.er, true)
|
||||
}
|
||||
if s.out != nil {
|
||||
go s.streamPump(s.out, false)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StarShell) streamPump(reader *bufio.Reader, isStderr bool) {
|
||||
var cache []byte
|
||||
|
||||
for {
|
||||
read, err := reader.ReadByte()
|
||||
if err == io.EOF {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
s.setError(err)
|
||||
return
|
||||
}
|
||||
|
||||
s.rw.Lock()
|
||||
if isStderr {
|
||||
s.errbyte = append(s.errbyte, read)
|
||||
} else {
|
||||
s.outbyte = append(s.outbyte, read)
|
||||
}
|
||||
printEnabled := s.isprint
|
||||
funcEnabled := s.isfuncs && s.funcs != nil
|
||||
lineHandler := s.funcs
|
||||
trimColor := s.iscolor
|
||||
s.rw.Unlock()
|
||||
|
||||
if printEnabled {
|
||||
fmt.Print(string([]byte{read}))
|
||||
}
|
||||
|
||||
cache = append(cache, read)
|
||||
if read == '\n' {
|
||||
if funcEnabled {
|
||||
line := strings.TrimSpace(string(cache))
|
||||
if trimColor {
|
||||
line = SedColor(line)
|
||||
}
|
||||
go lineHandler(line)
|
||||
}
|
||||
cache = cache[:0]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StarShell) setError(err error) {
|
||||
if err == nil || errors.Is(err, io.EOF) {
|
||||
return
|
||||
}
|
||||
|
||||
s.rw.Lock()
|
||||
defer s.rw.Unlock()
|
||||
if s.errors == nil {
|
||||
s.errors = err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StarShell) readState() ([]byte, []byte, error) {
|
||||
s.rw.RLock()
|
||||
defer s.rw.RUnlock()
|
||||
|
||||
outCopy := make([]byte, len(s.outbyte))
|
||||
copy(outCopy, s.outbyte)
|
||||
|
||||
errCopy := make([]byte, len(s.errbyte))
|
||||
copy(errCopy, s.errbyte)
|
||||
|
||||
if s.errors != nil {
|
||||
return outCopy, errCopy, s.errors
|
||||
}
|
||||
|
||||
return outCopy, errCopy, nil
|
||||
}
|
||||
|
||||
func stripCommandEchoFromOutput(output string, cmd string) string {
|
||||
if output == "" || cmd == "" {
|
||||
return strings.TrimSpace(output)
|
||||
}
|
||||
|
||||
lines := strings.Split(output, "\n")
|
||||
cmdLines := strings.Split(cmd, "\n")
|
||||
|
||||
for _, cmdLine := range cmdLines {
|
||||
trimmedCmd := strings.TrimSpace(cmdLine)
|
||||
if trimmedCmd == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
for i, line := range lines {
|
||||
if strings.TrimSpace(line) == trimmedCmd {
|
||||
lines = append(lines[:i], lines[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return strings.TrimSpace(strings.Join(lines, "\n"))
|
||||
}
|
||||
|
||||
func looksLikeNonPOSIXShellError(output string) bool {
|
||||
if strings.TrimSpace(output) == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
lower := strings.ToLower(output)
|
||||
indicators := []string{
|
||||
"is not recognized as an internal or external command",
|
||||
"the term ",
|
||||
"command not found",
|
||||
"unknown command",
|
||||
"not found",
|
||||
"not recognized",
|
||||
}
|
||||
for _, indicator := range indicators {
|
||||
if strings.Contains(lower, indicator) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *StarShell) GetUid() string {
|
||||
res, _, _ := s.ShellWait("id -u")
|
||||
return strings.TrimSpace(res)
|
||||
}
|
||||
|
||||
func (s *StarShell) GetGid() string {
|
||||
res, _, _ := s.ShellWait("id -g")
|
||||
return strings.TrimSpace(res)
|
||||
}
|
||||
|
||||
func (s *StarShell) GetUser() string {
|
||||
res, _, _ := s.ShellWait("id -un")
|
||||
return strings.TrimSpace(res)
|
||||
}
|
||||
|
||||
func (s *StarShell) GetGroup() string {
|
||||
res, _, _ := s.ShellWait("id -gn")
|
||||
return strings.TrimSpace(res)
|
||||
}
|
||||
@@ -1,636 +0,0 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type StarSSH struct {
|
||||
Client *ssh.Client
|
||||
PublicKey ssh.PublicKey
|
||||
PubkeyBase64 string
|
||||
Hostname string
|
||||
RemoteAddr net.Addr
|
||||
Banner string
|
||||
LoginInfo LoginInput
|
||||
online bool
|
||||
}
|
||||
|
||||
type LoginInput struct {
|
||||
KeyExchanges []string
|
||||
Ciphers []string
|
||||
MACs []string
|
||||
User string
|
||||
Password string
|
||||
Prikey string
|
||||
PrikeyPwd string
|
||||
Addr string
|
||||
Port int
|
||||
Timeout time.Duration
|
||||
HostKeyCallback func(string, net.Addr, ssh.PublicKey) error
|
||||
BannerCallback func(string) error
|
||||
}
|
||||
type StarShell struct {
|
||||
Keyword string
|
||||
UseWaitDefault bool
|
||||
Session *ssh.Session
|
||||
in io.Writer
|
||||
out *bufio.Reader
|
||||
er *bufio.Reader
|
||||
outbyte []byte
|
||||
errbyte []byte
|
||||
lastout int64
|
||||
errors error
|
||||
isprint bool
|
||||
isfuncs bool
|
||||
iscolor bool
|
||||
isecho bool
|
||||
rw sync.RWMutex
|
||||
funcs func(string)
|
||||
}
|
||||
|
||||
func Login(info LoginInput) (*StarSSH, error) {
|
||||
var (
|
||||
auth []ssh.AuthMethod
|
||||
clientConfig *ssh.ClientConfig
|
||||
config ssh.Config
|
||||
err error
|
||||
)
|
||||
sshInfo := new(StarSSH)
|
||||
// get auth method
|
||||
auth = make([]ssh.AuthMethod, 0)
|
||||
if info.Prikey == "" {
|
||||
keyboardInteractiveChallenge := func(
|
||||
user,
|
||||
instruction string,
|
||||
questions []string,
|
||||
echos []bool,
|
||||
) (answers []string, err error) {
|
||||
if len(questions) == 0 {
|
||||
return []string{}, nil
|
||||
}
|
||||
return []string{info.Password}, nil
|
||||
}
|
||||
auth = append(auth, ssh.Password(info.Password))
|
||||
auth = append(auth, ssh.KeyboardInteractive(keyboardInteractiveChallenge))
|
||||
} else {
|
||||
pemBytes := []byte(info.Prikey)
|
||||
var signer ssh.Signer
|
||||
if info.PrikeyPwd == "" {
|
||||
signer, err = ssh.ParsePrivateKey(pemBytes)
|
||||
} else {
|
||||
signer, err = ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(info.PrikeyPwd))
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
auth = append(auth, ssh.PublicKeys(signer))
|
||||
}
|
||||
|
||||
if len(info.Ciphers) == 0 {
|
||||
config = ssh.Config{
|
||||
Ciphers: []string{"aes128-ctr", "aes192-ctr", "aes256-ctr", "aes128-gcm@openssh.com", "arcfour256", "arcfour128", "aes128-cbc", "3des-cbc", "aes192-cbc", "aes256-cbc", "chacha20-poly1305@openssh.com"},
|
||||
}
|
||||
} else {
|
||||
config = ssh.Config{
|
||||
Ciphers: info.Ciphers,
|
||||
}
|
||||
}
|
||||
|
||||
if len(info.MACs) != 0 {
|
||||
config.MACs = info.MACs
|
||||
}
|
||||
if len(info.KeyExchanges) != 0 {
|
||||
config.KeyExchanges = info.KeyExchanges
|
||||
}
|
||||
|
||||
if info.Timeout == 0 {
|
||||
info.Timeout = time.Second * 5
|
||||
}
|
||||
hostKeycbfunc := func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
|
||||
sshInfo.PublicKey = key
|
||||
sshInfo.RemoteAddr = remote
|
||||
sshInfo.Hostname = hostname
|
||||
if info.HostKeyCallback != nil {
|
||||
return info.HostKeyCallback(hostname, remote, key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
bannercbfunc := func(banner string) error {
|
||||
sshInfo.Banner = banner
|
||||
if info.BannerCallback != nil {
|
||||
return info.BannerCallback(banner)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
clientConfig = &ssh.ClientConfig{
|
||||
User: info.User,
|
||||
Auth: auth,
|
||||
Timeout: info.Timeout,
|
||||
Config: config,
|
||||
HostKeyCallback: hostKeycbfunc,
|
||||
BannerCallback: bannercbfunc,
|
||||
}
|
||||
|
||||
// connet to ssh
|
||||
|
||||
sshInfo.LoginInfo = info
|
||||
sshInfo.Client, err = ssh.Dial("tcp", fmt.Sprintf("%s:%d", info.Addr, info.Port), clientConfig)
|
||||
if err == nil && sshInfo.PublicKey != nil {
|
||||
sshInfo.online = true
|
||||
sshInfo.PubkeyBase64 = base64.StdEncoding.EncodeToString(sshInfo.PublicKey.Marshal())
|
||||
}
|
||||
return sshInfo, err
|
||||
}
|
||||
|
||||
func LoginSimple(host string, user string, passwd string, prikeyPath string, port int, timeout time.Duration) (*StarSSH, error) {
|
||||
var info = LoginInput{
|
||||
Addr: host,
|
||||
Port: port,
|
||||
Timeout: timeout,
|
||||
User: user,
|
||||
}
|
||||
if prikeyPath != "" {
|
||||
prikey, err := ioutil.ReadFile(prikeyPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info.Prikey = string(prikey)
|
||||
if passwd != "" {
|
||||
info.PrikeyPwd = passwd
|
||||
}
|
||||
} else {
|
||||
info.Password = passwd
|
||||
}
|
||||
return Login(info)
|
||||
}
|
||||
|
||||
func (s *StarShell) ShellWait(cmd string) (string, string, error) {
|
||||
var outc, errc string = " ", " "
|
||||
s.Clear()
|
||||
defer s.Clear()
|
||||
echo := "echo b7Y85R56TUY6R5UTb612"
|
||||
err := s.WriteCommand(cmd)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
err = s.WriteCommand(echo)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
for {
|
||||
time.Sleep(time.Millisecond * 120)
|
||||
outs := string(s.outbyte)
|
||||
errs := string(s.errbyte)
|
||||
outs = strings.TrimSpace(strings.ReplaceAll(outs, "\r\n", "\n"))
|
||||
errs = strings.TrimSpace(strings.ReplaceAll(errs, "\r\n", "\n"))
|
||||
if len(outs) >= len(cmd+"\n"+echo) && outs[0:len(cmd+"\n"+echo)] == cmd+"\n"+echo {
|
||||
outs = outs[len(cmd+"\n"+echo):]
|
||||
} else if len(outs) >= len(cmd) && outs[0:len(cmd)] == cmd {
|
||||
outs = outs[len(cmd):]
|
||||
}
|
||||
if len(errs) >= len(cmd) && errs[0:len(cmd)] == cmd {
|
||||
errs = errs[len(cmd):]
|
||||
}
|
||||
if s.UseWaitDefault {
|
||||
if strings.Index(string(outs), "b7Y85R56TUY6R5UTb612") >= 0 {
|
||||
list := strings.Split(string(outs), "\n")
|
||||
for _, v := range list {
|
||||
if strings.Index(v, "b7Y85R56TUY6R5UTb612") < 0 {
|
||||
outc += v + "\n"
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
if strings.Index(string(errs), "b7Y85R56TUY6R5UTb612") >= 0 {
|
||||
list := strings.Split(string(errs), "\n")
|
||||
for _, v := range list {
|
||||
if strings.Index(v, "b7Y85R56TUY6R5UTb612") < 0 {
|
||||
errc += v + "\n"
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if s.Keyword != "" {
|
||||
if strings.Index(string(outs), s.Keyword) >= 0 {
|
||||
list := strings.Split(string(outs), "\n")
|
||||
for _, v := range list {
|
||||
if strings.Index(v, s.Keyword) < 0 && strings.Index(v, "b7Y85R56TUY6R5UTb612") < 0 {
|
||||
outc += v + "\n"
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
if strings.Index(string(errs), s.Keyword) >= 0 {
|
||||
list := strings.Split(string(errs), "\n")
|
||||
for _, v := range list {
|
||||
if strings.Index(v, s.Keyword) < 0 && strings.Index(v, "b7Y85R56TUY6R5UTb612") < 0 {
|
||||
errc += v + "\n"
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return s.TrimColor(strings.TrimSpace(outc)), s.TrimColor(strings.TrimSpace(errc)), err
|
||||
}
|
||||
|
||||
func (s *StarShell) Close() error {
|
||||
return s.Session.Close()
|
||||
}
|
||||
|
||||
func (s *StarShell) SwitchNoColor(is bool) {
|
||||
s.iscolor = is
|
||||
}
|
||||
|
||||
func (s *StarShell) SwitchEcho(is bool) {
|
||||
s.isecho = is
|
||||
}
|
||||
|
||||
func (s *StarShell) TrimColor(str string) string {
|
||||
if s.iscolor {
|
||||
return SedColor(str)
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
/*
|
||||
本函数控制是否在本地屏幕上打印远程Shell的输出内容[true|false]
|
||||
*/
|
||||
func (s *StarShell) SwitchPrint(run bool) {
|
||||
s.isprint = run
|
||||
}
|
||||
|
||||
/*
|
||||
本函数控制是否立即处理远程Shell输出每一行内容[true|false]
|
||||
*/
|
||||
func (s *StarShell) SwitchFunc(run bool) {
|
||||
s.isfuncs = run
|
||||
}
|
||||
|
||||
func (s *StarShell) SetFunc(funcs func(string)) {
|
||||
s.funcs = funcs
|
||||
}
|
||||
|
||||
func (s *StarShell) Clear() {
|
||||
defer s.rw.Unlock()
|
||||
s.rw.Lock()
|
||||
s.outbyte = []byte{}
|
||||
s.errbyte = []byte{}
|
||||
time.Sleep(time.Millisecond * 15)
|
||||
}
|
||||
|
||||
func (s *StarShell) ShellClear(cmd string, sleep int) (string, string, error) {
|
||||
defer s.Clear()
|
||||
s.Clear()
|
||||
return s.Shell(cmd, sleep)
|
||||
}
|
||||
|
||||
func (s *StarShell) Shell(cmd string, sleep int) (string, string, error) {
|
||||
if err := s.WriteCommand(cmd); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
tmp1, tmp2, err := s.GetResult(sleep)
|
||||
tmps := s.TrimColor(strings.TrimSpace(string(tmp1)))
|
||||
if s.isecho {
|
||||
n := len(strings.Split(cmd, "\n"))
|
||||
if n == 1 {
|
||||
list := strings.SplitN(tmps, "\n", 2)
|
||||
if len(list) == 2 {
|
||||
tmps = list[1]
|
||||
}
|
||||
} else {
|
||||
list := strings.Split(tmps, "\n")
|
||||
cmds := strings.Split(cmd, "\n")
|
||||
for _, v := range cmds {
|
||||
for k, v2 := range list {
|
||||
if strings.TrimSpace(v2) == strings.TrimSpace(v) {
|
||||
list[k] = ""
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
tmps = ""
|
||||
for _, v := range list {
|
||||
if v != "" {
|
||||
tmps += v + "\n"
|
||||
}
|
||||
}
|
||||
tmps = tmps[0 : len(tmps)-1]
|
||||
}
|
||||
}
|
||||
return tmps, s.TrimColor(strings.TrimSpace(string(tmp2))), err
|
||||
}
|
||||
|
||||
func (s *StarShell) GetResult(sleep int) ([]byte, []byte, error) {
|
||||
if s.errors != nil {
|
||||
s.Session.Close()
|
||||
return s.outbyte, s.errbyte, s.errors
|
||||
}
|
||||
if sleep > 0 {
|
||||
time.Sleep(time.Millisecond * time.Duration(sleep))
|
||||
}
|
||||
return s.outbyte, s.errbyte, nil
|
||||
}
|
||||
|
||||
func (s *StarShell) WriteCommand(cmd string) error {
|
||||
return s.Write([]byte(cmd + "\n"))
|
||||
}
|
||||
|
||||
func (s *StarShell) Write(bstr []byte) error {
|
||||
if s.errors != nil {
|
||||
s.Session.Close()
|
||||
return s.errors
|
||||
}
|
||||
_, err := s.in.Write(bstr)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *StarShell) gohub() {
|
||||
go func() {
|
||||
var cache []byte
|
||||
for {
|
||||
read, err := s.er.ReadByte()
|
||||
if err != nil {
|
||||
s.errors = err
|
||||
return
|
||||
}
|
||||
s.errbyte = append(s.errbyte, read)
|
||||
if s.isprint {
|
||||
fmt.Print(string([]byte{read}))
|
||||
}
|
||||
cache = append(cache, read)
|
||||
if read == '\n' {
|
||||
if s.isfuncs {
|
||||
go s.funcs(s.TrimColor(strings.TrimSpace(string(cache))))
|
||||
cache = []byte{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
var cache []byte
|
||||
for {
|
||||
read, err := s.out.ReadByte()
|
||||
if err != nil {
|
||||
s.errors = err
|
||||
return
|
||||
}
|
||||
s.rw.Lock()
|
||||
s.outbyte = append(s.outbyte, read)
|
||||
cache = append(cache, read)
|
||||
s.rw.Unlock()
|
||||
if read == '\n' {
|
||||
if s.isfuncs {
|
||||
go s.funcs(strings.TrimSpace(string(cache)))
|
||||
cache = []byte{}
|
||||
}
|
||||
}
|
||||
if s.isprint {
|
||||
fmt.Print(string([]byte{read}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StarShell) GetUid() string {
|
||||
res, _, _ := s.ShellWait(`id | grep -oP "(?<=uid\=)\d+"`)
|
||||
return strings.TrimSpace(res)
|
||||
}
|
||||
func (s *StarShell) GetGid() string {
|
||||
res, _, _ := s.ShellWait(`id | grep -oP "(?<=gid\=)\d+"`)
|
||||
return strings.TrimSpace(res)
|
||||
}
|
||||
func (s *StarShell) GetUser() string {
|
||||
res, _, _ := s.ShellWait(`id | grep -oP "(?<=\().*?(?=\))" | head -n 1`)
|
||||
return strings.TrimSpace(res)
|
||||
}
|
||||
func (s *StarShell) GetGroup() string {
|
||||
res, _, _ := s.ShellWait(`id | grep -oP "(?<=\().*?(?=\))" | head -n 2 | tail -n 1`)
|
||||
return strings.TrimSpace(res)
|
||||
}
|
||||
|
||||
func (s *StarSSH) NewShell() (shell *StarShell, err error) {
|
||||
shell = new(StarShell)
|
||||
shell.Session, err = s.NewSession()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
shell.in, _ = shell.Session.StdinPipe()
|
||||
tmp, _ := shell.Session.StdoutPipe()
|
||||
shell.out = bufio.NewReader(tmp)
|
||||
tmp, _ = shell.Session.StderrPipe()
|
||||
shell.er = bufio.NewReader(tmp)
|
||||
err = shell.Session.Shell()
|
||||
shell.isecho = true
|
||||
go shell.Session.Wait()
|
||||
shell.UseWaitDefault = true
|
||||
shell.WriteCommand("bash")
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
shell.WriteCommand("export PS1= ")
|
||||
shell.WriteCommand("export PS2= ")
|
||||
go shell.gohub()
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
shell.Clear()
|
||||
shell.Clear()
|
||||
return
|
||||
}
|
||||
|
||||
func (s *StarSSH) Close() error {
|
||||
if s.online {
|
||||
return s.Client.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *StarSSH) NewSession() (*ssh.Session, error) {
|
||||
return NewSession(s.Client)
|
||||
}
|
||||
|
||||
func (s *StarSSH) ShellOne(cmd string) (string, error) {
|
||||
newsess, err := s.NewSession()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
data, err := newsess.CombinedOutput(cmd)
|
||||
newsess.Close()
|
||||
return strings.TrimSpace(string(data)), err
|
||||
}
|
||||
|
||||
func (s *StarSSH) Exists(filepath string) bool {
|
||||
res, _ := s.ShellOne(`echo 1 && [ ! -e "` + filepath + `" ] && echo 2`)
|
||||
if res == "1" {
|
||||
return true
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StarSSH) GetUid() string {
|
||||
res, _ := s.ShellOne(`id | grep -oP "(?<=uid\=)\d+"`)
|
||||
return strings.TrimSpace(res)
|
||||
}
|
||||
func (s *StarSSH) GetGid() string {
|
||||
res, _ := s.ShellOne(`id | grep -oP "(?<=gid\=)\d+"`)
|
||||
return strings.TrimSpace(res)
|
||||
}
|
||||
func (s *StarSSH) GetUser() string {
|
||||
res, _ := s.ShellOne(`id | grep -oP "(?<=\().*?(?=\))" | head -n 1`)
|
||||
return strings.TrimSpace(res)
|
||||
}
|
||||
func (s *StarSSH) GetGroup() string {
|
||||
res, _ := s.ShellOne(`id | grep -oP "(?<=\().*?(?=\))" | head -n 2 | tail -n 1`)
|
||||
return strings.TrimSpace(res)
|
||||
}
|
||||
|
||||
func (s *StarSSH) IsFile(filepath string) bool {
|
||||
res, _ := s.ShellOne(`echo 1 && [ ! -f "` + filepath + `" ] && echo 2`)
|
||||
if res == "1" {
|
||||
return true
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StarSSH) IsFolder(filepath string) bool {
|
||||
res, _ := s.ShellOne(`echo 1 && [ ! -d "` + filepath + `" ] && echo 2`)
|
||||
if res == "1" {
|
||||
return true
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StarSSH) ShellOneShowScreen(cmd string) (string, error) {
|
||||
newsess, err := s.NewSession()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var bytes, errbytes []byte
|
||||
tmp, _ := newsess.StdoutPipe()
|
||||
reader := bufio.NewReader(tmp)
|
||||
tmp, _ = newsess.StderrPipe()
|
||||
errder := bufio.NewReader(tmp)
|
||||
err = newsess.Start(cmd)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
c := make(chan int, 1)
|
||||
go newsess.Wait()
|
||||
go func() {
|
||||
for {
|
||||
byt, err := reader.ReadByte()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
fmt.Print(string([]byte{byt}))
|
||||
bytes = append(bytes, byt)
|
||||
}
|
||||
c <- 1
|
||||
}()
|
||||
for {
|
||||
byt, err := errder.ReadByte()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
fmt.Print(string([]byte{byt}))
|
||||
errbytes = append(errbytes, byt)
|
||||
}
|
||||
_ = <-c
|
||||
newsess.Close()
|
||||
if len(errbytes) != 0 {
|
||||
err = errors.New(strings.TrimSpace(string(errbytes)))
|
||||
} else {
|
||||
err = nil
|
||||
}
|
||||
return strings.TrimSpace(string(bytes)), err
|
||||
}
|
||||
|
||||
func (s *StarSSH) ShellOneToFunc(cmd string, callback func(string)) (string, error) {
|
||||
newsess, err := s.NewSession()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var bytes, errbytes []byte
|
||||
tmp, _ := newsess.StdoutPipe()
|
||||
reader := bufio.NewReader(tmp)
|
||||
tmp, _ = newsess.StderrPipe()
|
||||
errder := bufio.NewReader(tmp)
|
||||
err = newsess.Start(cmd)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
c := make(chan int, 1)
|
||||
go newsess.Wait()
|
||||
go func() {
|
||||
for {
|
||||
byt, err := reader.ReadByte()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
callback(string([]byte{byt}))
|
||||
bytes = append(bytes, byt)
|
||||
}
|
||||
c <- 1
|
||||
}()
|
||||
for {
|
||||
byt, err := errder.ReadByte()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
callback(string([]byte{byt}))
|
||||
errbytes = append(errbytes, byt)
|
||||
}
|
||||
_ = <-c
|
||||
newsess.Close()
|
||||
if len(errbytes) != 0 {
|
||||
err = errors.New(strings.TrimSpace(string(errbytes)))
|
||||
} else {
|
||||
err = nil
|
||||
}
|
||||
return strings.TrimSpace(string(bytes)), err
|
||||
}
|
||||
|
||||
func NewTransferSession(client *ssh.Client) (*ssh.Session, error) {
|
||||
session, err := client.NewSession()
|
||||
return session, err
|
||||
}
|
||||
|
||||
func NewSession(client *ssh.Client) (*ssh.Session, error) {
|
||||
var session *ssh.Session
|
||||
var err error
|
||||
// create session
|
||||
if session, err = client.NewSession(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
modes := ssh.TerminalModes{
|
||||
ssh.ECHO: 1, // 还是要强制开启
|
||||
//ssh.IGNCR: 0,
|
||||
ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
|
||||
ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
|
||||
}
|
||||
|
||||
if err := session.RequestPty("xterm", 500, 250, modes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func SedColor(str string) string {
|
||||
reg := regexp.MustCompile(`\x1B\[([0-9]{1,2}(;[0-9]{1,2})?)?[m|K]`)
|
||||
//fmt.Println("regexp:", reg.Match([]byte(str)))
|
||||
return string(reg.ReplaceAll([]byte(str), []byte("")))
|
||||
}
|
||||
@@ -0,0 +1,668 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
sshagent "golang.org/x/crypto/ssh/agent"
|
||||
)
|
||||
|
||||
var errSSHAgentUnavailable = errors.New("ssh-agent unavailable")
|
||||
var errRetrySSHAgentAuth = errors.New("retry ssh-agent auth")
|
||||
var buildSSHAgentAuthMethodFunc = buildSSHAgentAuthMethod
|
||||
|
||||
type sshAgentTimeouts struct {
|
||||
Dial time.Duration
|
||||
Operation time.Duration
|
||||
Forward time.Duration
|
||||
Endpoint string
|
||||
Resolved resolvedSSHAgentEndpoint
|
||||
Debug SSHAgentDebugFunc
|
||||
SkipFingerprints map[string]struct{}
|
||||
SignFailure func(ssh.PublicKey, error)
|
||||
}
|
||||
|
||||
type sshAgentAuthAttempt struct {
|
||||
mu sync.Mutex
|
||||
skipFingerprints map[string]struct{}
|
||||
retryRequested bool
|
||||
}
|
||||
|
||||
var defaultAuthOrder = []AuthMethodKind{
|
||||
AuthMethodSSHAgent,
|
||||
AuthMethodPrivateKey,
|
||||
AuthMethodPassword,
|
||||
AuthMethodKeyboardInteractive,
|
||||
}
|
||||
|
||||
func effectiveSSHAgentTimeout(info LoginInput) time.Duration {
|
||||
switch {
|
||||
case info.SSHAgentTimeout < 0:
|
||||
return 0
|
||||
case info.SSHAgentTimeout > 0:
|
||||
return info.SSHAgentTimeout
|
||||
default:
|
||||
return defaultSSHAgentTimeout
|
||||
}
|
||||
}
|
||||
|
||||
func effectiveSSHAgentTimeouts(info LoginInput) sshAgentTimeouts {
|
||||
return sshAgentTimeouts{
|
||||
Dial: effectiveDialTimeout(info),
|
||||
Operation: effectiveSSHAgentTimeout(info),
|
||||
Forward: effectiveSSHAgentForwardTimeout(info),
|
||||
Endpoint: info.IdentityAgent,
|
||||
Debug: info.SSHAgentDebug,
|
||||
}
|
||||
}
|
||||
|
||||
func effectiveSSHAgentForwardTimeout(info LoginInput) time.Duration {
|
||||
if info.SSHAgentForwardTimeout > 0 {
|
||||
return info.SSHAgentForwardTimeout
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) {
|
||||
return buildAuthMethodsWithAgentAttempt(info, nil)
|
||||
}
|
||||
|
||||
func buildAuthMethodsWithAgentAttempt(info LoginInput, agentAttempt *sshAgentAuthAttempt) ([]ssh.AuthMethod, func(), error) {
|
||||
order, err := normalizeAuthOrder(info.AuthOrder)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
auth := make([]ssh.AuthMethod, 0, len(order))
|
||||
var agentErr error
|
||||
var cleanupFuncs []func()
|
||||
|
||||
for _, methodKind := range order {
|
||||
switch methodKind {
|
||||
case AuthMethodPrivateKey:
|
||||
method, err := buildPrivateKeyAuthMethod(info, agentAttempt)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if method != nil {
|
||||
auth = append(auth, method)
|
||||
}
|
||||
case AuthMethodPassword:
|
||||
method := buildPasswordAuthMethod(info.Password, info.PasswordCallback, agentAttempt)
|
||||
if method != nil {
|
||||
auth = append(auth, method)
|
||||
}
|
||||
case AuthMethodKeyboardInteractive:
|
||||
method := buildKeyboardInteractiveAuthMethod(info.Password, info.PasswordCallback, info.KeyboardInteractiveCallback, agentAttempt)
|
||||
if method != nil {
|
||||
auth = append(auth, method)
|
||||
}
|
||||
case AuthMethodSSHAgent:
|
||||
if info.DisableSSHAgent {
|
||||
continue
|
||||
}
|
||||
timeouts := effectiveSSHAgentTimeouts(info)
|
||||
if agentAttempt != nil {
|
||||
timeouts.SkipFingerprints = agentAttempt.skipSnapshot()
|
||||
timeouts.SignFailure = agentAttempt.recordSignFailure
|
||||
}
|
||||
agentMethod, cleanup, err := buildSSHAgentAuthMethodFunc(timeouts)
|
||||
if err != nil {
|
||||
agentErr = err
|
||||
continue
|
||||
}
|
||||
if agentMethod != nil {
|
||||
auth = append(auth, agentMethod)
|
||||
}
|
||||
if cleanup != nil {
|
||||
cleanupFuncs = append(cleanupFuncs, cleanup)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(auth) == 0 {
|
||||
if agentErr != nil {
|
||||
return nil, nil, fmt.Errorf("no authentication method provided; ssh-agent unavailable: %w", agentErr)
|
||||
}
|
||||
return nil, nil, errors.New("no authentication method provided: password, private key, or ssh-agent is required")
|
||||
}
|
||||
|
||||
return auth, composeCleanup(cleanupFuncs...), nil
|
||||
}
|
||||
|
||||
func normalizeAuthOrder(order []AuthMethodKind) ([]AuthMethodKind, error) {
|
||||
if len(order) == 0 {
|
||||
return append([]AuthMethodKind(nil), defaultAuthOrder...), nil
|
||||
}
|
||||
|
||||
normalized := make([]AuthMethodKind, 0, len(order))
|
||||
seen := make(map[AuthMethodKind]struct{}, len(order))
|
||||
for _, raw := range order {
|
||||
kind := AuthMethodKind(strings.ToLower(strings.TrimSpace(string(raw))))
|
||||
if kind == "" {
|
||||
return nil, errors.New("auth order contains an empty auth method")
|
||||
}
|
||||
if !isSupportedAuthMethodKind(kind) {
|
||||
return nil, fmt.Errorf("unsupported auth method %q", raw)
|
||||
}
|
||||
if _, exists := seen[kind]; exists {
|
||||
continue
|
||||
}
|
||||
seen[kind] = struct{}{}
|
||||
normalized = append(normalized, kind)
|
||||
}
|
||||
|
||||
if len(normalized) == 0 {
|
||||
return nil, errors.New("auth order is empty")
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
func isSupportedAuthMethodKind(kind AuthMethodKind) bool {
|
||||
switch kind {
|
||||
case AuthMethodPrivateKey, AuthMethodPassword, AuthMethodKeyboardInteractive, AuthMethodSSHAgent:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func shouldRetrySSHAgentAuth(info LoginInput, order []AuthMethodKind) bool {
|
||||
if info.DisableSSHAgent {
|
||||
return false
|
||||
}
|
||||
for _, methodKind := range order {
|
||||
if methodKind == AuthMethodSSHAgent {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func buildPrivateKeyAuthMethod(info LoginInput, agentAttempt *sshAgentAuthAttempt) (ssh.AuthMethod, error) {
|
||||
if strings.TrimSpace(info.Prikey) == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
pemBytes := []byte(info.Prikey)
|
||||
if info.PrikeyPwd == "" {
|
||||
signer, err := ssh.ParsePrivateKey(pemBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ssh.PublicKeysCallback(privateKeySignersCallback(signer, agentAttempt)), nil
|
||||
}
|
||||
|
||||
signer, err := ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(info.PrikeyPwd))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ssh.PublicKeysCallback(privateKeySignersCallback(signer, agentAttempt)), nil
|
||||
}
|
||||
|
||||
func privateKeySignersCallback(signer ssh.Signer, agentAttempt *sshAgentAuthAttempt) func() ([]ssh.Signer, error) {
|
||||
return func() ([]ssh.Signer, error) {
|
||||
if err := checkSSHAgentRetryPending(agentAttempt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []ssh.Signer{signer}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func buildPasswordAuthMethod(password string, callback func() (string, error), agentAttempt *sshAgentAuthAttempt) ssh.AuthMethod {
|
||||
if password == "" && callback == nil {
|
||||
return nil
|
||||
}
|
||||
return ssh.PasswordCallback(func() (string, error) {
|
||||
if err := checkSSHAgentRetryPending(agentAttempt); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if password != "" {
|
||||
return password, nil
|
||||
}
|
||||
return callback()
|
||||
})
|
||||
}
|
||||
|
||||
func buildKeyboardInteractiveAuthMethod(
|
||||
password string,
|
||||
passwordCallback func() (string, error),
|
||||
challenge ssh.KeyboardInteractiveChallenge,
|
||||
agentAttempt *sshAgentAuthAttempt,
|
||||
) ssh.AuthMethod {
|
||||
if challenge != nil {
|
||||
return ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) ([]string, error) {
|
||||
if err := checkSSHAgentRetryPending(agentAttempt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return challenge(user, instruction, questions, echos)
|
||||
})
|
||||
}
|
||||
if password == "" && passwordCallback == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
keyboardInteractiveChallenge := func(user, instruction string, questions []string, echos []bool) ([]string, error) {
|
||||
if err := checkSSHAgentRetryPending(agentAttempt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(questions) == 0 {
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
answer := password
|
||||
if answer == "" {
|
||||
var err error
|
||||
answer, err = passwordCallback()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
answers := make([]string, len(questions))
|
||||
for i := range questions {
|
||||
answers[i] = answer
|
||||
}
|
||||
return answers, nil
|
||||
}
|
||||
return ssh.KeyboardInteractive(keyboardInteractiveChallenge)
|
||||
}
|
||||
|
||||
func buildSSHAgentAuthMethod(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
|
||||
conn, resolved, err := dialSSHAgentWithDebug("auth", timeouts)
|
||||
if err != nil {
|
||||
if errors.Is(err, errSSHAgentUnavailable) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
return nil, nil, err
|
||||
}
|
||||
if conn == nil {
|
||||
return nil, nil, nil
|
||||
}
|
||||
conn = wrapSSHAgentConnWithDeadline(conn, timeouts.Operation)
|
||||
|
||||
started := time.Now()
|
||||
signers, err := sshagent.NewClient(conn).Signers()
|
||||
err = normalizeSSHAgentError(err)
|
||||
logSSHAgentDebug(timeouts.Debug, SSHAgentDebugEvent{
|
||||
Step: "auth",
|
||||
Source: resolved.Source,
|
||||
Endpoint: resolved.Endpoint,
|
||||
Network: resolved.Network,
|
||||
Phase: "list",
|
||||
Status: debugStatus(err),
|
||||
Duration: time.Since(started),
|
||||
KeyCount: len(signers),
|
||||
Err: err,
|
||||
})
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
if len(signers) == 0 {
|
||||
_ = conn.Close()
|
||||
return nil, nil, errors.New("ssh-agent has no loaded keys")
|
||||
}
|
||||
|
||||
timeouts.Resolved = resolved
|
||||
orderedSigners := orderSSHAgentSigners(signers)
|
||||
filteredSigners := filterSSHAgentSignersForRetry(orderedSigners, timeouts)
|
||||
if len(filteredSigners) == 0 {
|
||||
_ = conn.Close()
|
||||
return nil, nil, errors.New("ssh-agent has no usable keys")
|
||||
}
|
||||
|
||||
return ssh.PublicKeys(filteredSigners...), func() {
|
||||
_ = conn.Close()
|
||||
}, nil
|
||||
}
|
||||
|
||||
func orderSSHAgentSigners(signers []ssh.Signer) []ssh.Signer {
|
||||
type orderedSigner struct {
|
||||
signer ssh.Signer
|
||||
index int
|
||||
score int
|
||||
comment string
|
||||
}
|
||||
|
||||
ordered := make([]orderedSigner, 0, len(signers))
|
||||
for index, signer := range signers {
|
||||
if signer == nil || signer.PublicKey() == nil {
|
||||
continue
|
||||
}
|
||||
ordered = append(ordered, orderedSigner{
|
||||
signer: signer,
|
||||
index: index,
|
||||
score: sshAgentSignerPriority(signer),
|
||||
comment: sshAgentSignerComment(signer),
|
||||
})
|
||||
}
|
||||
|
||||
sort.SliceStable(ordered, func(i, j int) bool {
|
||||
if ordered[i].score != ordered[j].score {
|
||||
return ordered[i].score > ordered[j].score
|
||||
}
|
||||
return ordered[i].index < ordered[j].index
|
||||
})
|
||||
|
||||
result := make([]ssh.Signer, 0, len(ordered))
|
||||
for _, item := range ordered {
|
||||
result = append(result, item.signer)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func sshAgentSignerComment(signer ssh.Signer) string {
|
||||
if signer == nil {
|
||||
return ""
|
||||
}
|
||||
if key, ok := signer.PublicKey().(*sshagent.Key); ok {
|
||||
return key.Comment
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func sshAgentSignerPriority(signer ssh.Signer) int {
|
||||
comment := strings.TrimSpace(sshAgentSignerComment(signer))
|
||||
if comment == "" {
|
||||
return 0
|
||||
}
|
||||
|
||||
score := 0
|
||||
if priority, ok := parseSSHAgentSignerPriority(comment); ok {
|
||||
score += 100000 + priority*1000
|
||||
}
|
||||
|
||||
lower := strings.ToLower(comment)
|
||||
if strings.Contains(lower, "current") {
|
||||
score += 400
|
||||
}
|
||||
if strings.Contains(lower, "cardno:") {
|
||||
score += 300
|
||||
}
|
||||
if strings.Contains(lower, "card ") || strings.Contains(lower, " card") || strings.Contains(lower, "card:") {
|
||||
score += 100
|
||||
}
|
||||
if strings.Contains(lower, "openpgp") || strings.Contains(lower, "gpg") {
|
||||
score += 50
|
||||
}
|
||||
return score
|
||||
}
|
||||
|
||||
func parseSSHAgentSignerPriority(comment string) (int, bool) {
|
||||
lower := strings.ToLower(comment)
|
||||
index := strings.Index(lower, "priority=")
|
||||
if index < 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
value := strings.TrimSpace(comment[index+len("priority="):])
|
||||
if value == "" {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
end := 0
|
||||
for end < len(value) {
|
||||
ch := value[end]
|
||||
if ch == '+' || ch == '-' || (ch >= '0' && ch <= '9') {
|
||||
end++
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
if end == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
priority, err := strconv.Atoi(value[:end])
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
return priority, true
|
||||
}
|
||||
|
||||
func filterSSHAgentSignersForRetry(signers []ssh.Signer, timeouts sshAgentTimeouts) []ssh.Signer {
|
||||
filteredSigners := make([]ssh.Signer, 0, len(signers))
|
||||
for _, signer := range signers {
|
||||
if signer == nil {
|
||||
continue
|
||||
}
|
||||
publicKey := signer.PublicKey()
|
||||
if publicKey == nil {
|
||||
continue
|
||||
}
|
||||
if _, skip := timeouts.SkipFingerprints[ssh.FingerprintSHA256(publicKey)]; skip {
|
||||
continue
|
||||
}
|
||||
if timeouts.SignFailure == nil && timeouts.Debug == nil {
|
||||
filteredSigners = append(filteredSigners, signer)
|
||||
continue
|
||||
}
|
||||
filteredSigners = append(filteredSigners, wrapSSHAgentSigner(signer, sshAgentSignerOptions{
|
||||
Resolved: timeouts.Resolved,
|
||||
Debug: timeouts.Debug,
|
||||
SignFailure: timeouts.SignFailure,
|
||||
}))
|
||||
}
|
||||
return filteredSigners
|
||||
}
|
||||
|
||||
func newSSHAgentAuthAttempt() *sshAgentAuthAttempt {
|
||||
return &sshAgentAuthAttempt{
|
||||
skipFingerprints: make(map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (a *sshAgentAuthAttempt) begin() {
|
||||
if a == nil {
|
||||
return
|
||||
}
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
a.retryRequested = false
|
||||
}
|
||||
|
||||
func (a *sshAgentAuthAttempt) skipSnapshot() map[string]struct{} {
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
if len(a.skipFingerprints) == 0 {
|
||||
return nil
|
||||
}
|
||||
snapshot := make(map[string]struct{}, len(a.skipFingerprints))
|
||||
for fingerprint := range a.skipFingerprints {
|
||||
snapshot[fingerprint] = struct{}{}
|
||||
}
|
||||
return snapshot
|
||||
}
|
||||
|
||||
func (a *sshAgentAuthAttempt) recordSignFailure(publicKey ssh.PublicKey, err error) {
|
||||
_ = err
|
||||
if a == nil || publicKey == nil {
|
||||
return
|
||||
}
|
||||
a.skipFingerprint(ssh.FingerprintSHA256(publicKey))
|
||||
}
|
||||
|
||||
func (a *sshAgentAuthAttempt) skipFingerprint(fingerprint string) {
|
||||
if a == nil {
|
||||
return
|
||||
}
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
a.retryRequested = true
|
||||
if fingerprint != "" {
|
||||
a.skipFingerprints[fingerprint] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *sshAgentAuthAttempt) shouldRetry() bool {
|
||||
if a == nil {
|
||||
return false
|
||||
}
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
return a.retryRequested
|
||||
}
|
||||
|
||||
func checkSSHAgentRetryPending(agentAttempt *sshAgentAuthAttempt) error {
|
||||
if agentAttempt != nil && agentAttempt.shouldRetry() {
|
||||
return errRetrySSHAgentAuth
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type sshAgentRetrySigner struct {
|
||||
signer ssh.Signer
|
||||
publicKey ssh.PublicKey
|
||||
options sshAgentSignerOptions
|
||||
}
|
||||
|
||||
type sshAgentRetryAlgorithmSigner struct {
|
||||
sshAgentRetrySigner
|
||||
algorithmSigner ssh.AlgorithmSigner
|
||||
}
|
||||
|
||||
type sshAgentRetryMultiAlgorithmSigner struct {
|
||||
sshAgentRetryAlgorithmSigner
|
||||
multiAlgorithmSigner ssh.MultiAlgorithmSigner
|
||||
}
|
||||
|
||||
type sshAgentSignerOptions struct {
|
||||
Resolved resolvedSSHAgentEndpoint
|
||||
Debug SSHAgentDebugFunc
|
||||
SignFailure func(ssh.PublicKey, error)
|
||||
}
|
||||
|
||||
func wrapSSHAgentSignerForRetry(signer ssh.Signer, onFailure func(ssh.PublicKey, error)) ssh.Signer {
|
||||
return wrapSSHAgentSigner(signer, sshAgentSignerOptions{SignFailure: onFailure})
|
||||
}
|
||||
|
||||
func wrapSSHAgentSigner(signer ssh.Signer, options sshAgentSignerOptions) ssh.Signer {
|
||||
publicKey := signer.PublicKey()
|
||||
base := sshAgentRetrySigner{
|
||||
signer: signer,
|
||||
publicKey: publicKey,
|
||||
options: options,
|
||||
}
|
||||
if multiAlgorithmSigner, ok := signer.(ssh.MultiAlgorithmSigner); ok {
|
||||
return &sshAgentRetryMultiAlgorithmSigner{
|
||||
sshAgentRetryAlgorithmSigner: sshAgentRetryAlgorithmSigner{
|
||||
sshAgentRetrySigner: base,
|
||||
algorithmSigner: multiAlgorithmSigner,
|
||||
},
|
||||
multiAlgorithmSigner: multiAlgorithmSigner,
|
||||
}
|
||||
}
|
||||
if algorithmSigner, ok := signer.(ssh.AlgorithmSigner); ok {
|
||||
return &sshAgentRetryAlgorithmSigner{
|
||||
sshAgentRetrySigner: base,
|
||||
algorithmSigner: algorithmSigner,
|
||||
}
|
||||
}
|
||||
return &base
|
||||
}
|
||||
|
||||
func (s *sshAgentRetrySigner) PublicKey() ssh.PublicKey {
|
||||
return s.publicKey
|
||||
}
|
||||
|
||||
func (s *sshAgentRetrySigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) {
|
||||
started := time.Now()
|
||||
signature, err := s.signer.Sign(rand, data)
|
||||
return signature, s.finishSign(started, err)
|
||||
}
|
||||
|
||||
func (s *sshAgentRetrySigner) finishSign(started time.Time, err error) error {
|
||||
err = normalizeSSHAgentError(err)
|
||||
s.logSignDebug(started, err)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if s.options.SignFailure != nil {
|
||||
s.options.SignFailure(s.publicKey, err)
|
||||
return wrapSSHAgentSignError(err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *sshAgentRetrySigner) logSignDebug(started time.Time, err error) {
|
||||
if s == nil || s.options.Debug == nil {
|
||||
return
|
||||
}
|
||||
logSSHAgentDebug(s.options.Debug, SSHAgentDebugEvent{
|
||||
Step: "auth",
|
||||
Source: s.options.Resolved.Source,
|
||||
Endpoint: s.options.Resolved.Endpoint,
|
||||
Network: s.options.Resolved.Network,
|
||||
Phase: "sign",
|
||||
Status: debugStatus(err),
|
||||
Duration: time.Since(started),
|
||||
Err: err,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *sshAgentRetryAlgorithmSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) {
|
||||
algorithm = preferredSSHAgentSignAlgorithm(s.publicKey, algorithm, nil)
|
||||
started := time.Now()
|
||||
signature, err := s.algorithmSigner.SignWithAlgorithm(rand, data, algorithm)
|
||||
return signature, s.finishSign(started, err)
|
||||
}
|
||||
|
||||
func (s *sshAgentRetryMultiAlgorithmSigner) Algorithms() []string {
|
||||
return s.multiAlgorithmSigner.Algorithms()
|
||||
}
|
||||
|
||||
func (s *sshAgentRetryMultiAlgorithmSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) {
|
||||
algorithm = preferredSSHAgentSignAlgorithm(s.publicKey, algorithm, s.multiAlgorithmSigner.Algorithms())
|
||||
started := time.Now()
|
||||
signature, err := s.multiAlgorithmSigner.SignWithAlgorithm(rand, data, algorithm)
|
||||
return signature, s.finishSign(started, err)
|
||||
}
|
||||
|
||||
func preferredSSHAgentSignAlgorithm(publicKey ssh.PublicKey, requested string, algorithms []string) string {
|
||||
if publicKey == nil || publicKey.Type() != ssh.KeyAlgoRSA || requested != ssh.KeyAlgoRSA {
|
||||
return requested
|
||||
}
|
||||
if len(algorithms) == 0 {
|
||||
return ssh.KeyAlgoRSASHA256
|
||||
}
|
||||
for _, algorithm := range algorithms {
|
||||
if algorithm == ssh.KeyAlgoRSA {
|
||||
break
|
||||
}
|
||||
if algorithm == ssh.KeyAlgoRSASHA256 || algorithm == ssh.KeyAlgoRSASHA512 {
|
||||
return algorithm
|
||||
}
|
||||
}
|
||||
return requested
|
||||
}
|
||||
|
||||
func wrapSSHAgentSignError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("%w: %v", errRetrySSHAgentAuth, normalizeSSHAgentError(err))
|
||||
}
|
||||
|
||||
func composeCleanup(funcs ...func()) func() {
|
||||
if len(funcs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return func() {
|
||||
for i := len(funcs) - 1; i >= 0; i-- {
|
||||
if funcs[i] != nil {
|
||||
funcs[i]()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,158 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ErrSSHAgentTimeout = errors.New("ssh-agent timeout")
|
||||
var dialResolvedSSHAgentFunc = dialResolvedSSHAgent
|
||||
|
||||
type sshAgentDialOptions struct {
|
||||
Endpoint string
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
type resolvedSSHAgentEndpoint struct {
|
||||
Endpoint string
|
||||
Source string
|
||||
Network string
|
||||
}
|
||||
|
||||
type deadlineAgentConn struct {
|
||||
net.Conn
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
func resolveSSHAgentEndpoint(options sshAgentDialOptions) (resolvedSSHAgentEndpoint, error) {
|
||||
endpoint := strings.TrimSpace(options.Endpoint)
|
||||
if endpoint != "" {
|
||||
return resolvedSSHAgentEndpoint{
|
||||
Endpoint: endpoint,
|
||||
Source: "identity-agent",
|
||||
Network: defaultSSHAgentNetwork(endpoint),
|
||||
}, nil
|
||||
}
|
||||
|
||||
endpoint = strings.TrimSpace(os.Getenv("SSH_AUTH_SOCK"))
|
||||
if endpoint != "" {
|
||||
return resolvedSSHAgentEndpoint{
|
||||
Endpoint: endpoint,
|
||||
Source: "SSH_AUTH_SOCK",
|
||||
Network: defaultSSHAgentNetwork(endpoint),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return defaultSSHAgentEndpoint()
|
||||
}
|
||||
|
||||
func dialSSHAgent(options sshAgentDialOptions) (net.Conn, resolvedSSHAgentEndpoint, error) {
|
||||
resolved, err := resolveSSHAgentEndpoint(options)
|
||||
if err != nil {
|
||||
return nil, resolvedSSHAgentEndpoint{}, err
|
||||
}
|
||||
|
||||
conn, err := dialResolvedSSHAgentFunc(resolved, options.Timeout)
|
||||
if isTimeoutError(err) {
|
||||
err = fmt.Errorf("%w: %v", ErrSSHAgentTimeout, err)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, resolved, err
|
||||
}
|
||||
return conn, resolved, nil
|
||||
}
|
||||
|
||||
func dialSSHAgentWithDebug(step string, timeouts sshAgentTimeouts) (net.Conn, resolvedSSHAgentEndpoint, error) {
|
||||
options := sshAgentDialOptions{
|
||||
Endpoint: timeouts.Endpoint,
|
||||
Timeout: timeouts.Dial,
|
||||
}
|
||||
started := time.Now()
|
||||
conn, resolved, err := dialSSHAgent(options)
|
||||
logSSHAgentDebug(timeouts.Debug, SSHAgentDebugEvent{
|
||||
Step: step,
|
||||
Source: resolved.Source,
|
||||
Endpoint: resolved.Endpoint,
|
||||
Network: resolved.Network,
|
||||
Phase: "dial",
|
||||
Status: debugStatus(err),
|
||||
Duration: time.Since(started),
|
||||
Err: err,
|
||||
})
|
||||
return conn, resolved, err
|
||||
}
|
||||
|
||||
func logSSHAgentDebug(debug SSHAgentDebugFunc, event SSHAgentDebugEvent) {
|
||||
if debug == nil {
|
||||
return
|
||||
}
|
||||
debug(event)
|
||||
}
|
||||
|
||||
func debugStatus(err error) string {
|
||||
if err != nil {
|
||||
return "error"
|
||||
}
|
||||
return "ok"
|
||||
}
|
||||
|
||||
func wrapSSHAgentConnWithDeadline(conn net.Conn, timeout time.Duration) net.Conn {
|
||||
if conn == nil || timeout <= 0 {
|
||||
return conn
|
||||
}
|
||||
return &deadlineAgentConn{Conn: conn, timeout: timeout}
|
||||
}
|
||||
|
||||
func (c *deadlineAgentConn) Read(p []byte) (int, error) {
|
||||
c.setDeadline()
|
||||
n, err := c.Conn.Read(p)
|
||||
return n, wrapSSHAgentConnError(err)
|
||||
}
|
||||
|
||||
func (c *deadlineAgentConn) Write(p []byte) (int, error) {
|
||||
c.setDeadline()
|
||||
n, err := c.Conn.Write(p)
|
||||
return n, wrapSSHAgentConnError(err)
|
||||
}
|
||||
|
||||
func (c *deadlineAgentConn) setDeadline() {
|
||||
if c == nil || c.timeout <= 0 || c.Conn == nil {
|
||||
return
|
||||
}
|
||||
_ = c.Conn.SetDeadline(time.Now().Add(c.timeout))
|
||||
}
|
||||
|
||||
func isTimeoutError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
return true
|
||||
}
|
||||
var netErr net.Error
|
||||
return errors.As(err, &netErr) && netErr.Timeout()
|
||||
}
|
||||
|
||||
func wrapSSHAgentConnError(err error) error {
|
||||
if isTimeoutError(err) {
|
||||
return fmt.Errorf("%w: %v", ErrSSHAgentTimeout, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func normalizeSSHAgentError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, ErrSSHAgentTimeout) {
|
||||
return err
|
||||
}
|
||||
if strings.Contains(err.Error(), ErrSSHAgentTimeout.Error()) {
|
||||
return fmt.Errorf("%w: %v", ErrSSHAgentTimeout, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
//go:build !windows
|
||||
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
func defaultSSHAgentEndpoint() (resolvedSSHAgentEndpoint, error) {
|
||||
return resolvedSSHAgentEndpoint{}, errSSHAgentUnavailable
|
||||
}
|
||||
|
||||
func defaultSSHAgentNetwork(endpoint string) string {
|
||||
return "unix"
|
||||
}
|
||||
|
||||
func dialResolvedSSHAgent(resolved resolvedSSHAgentEndpoint, timeout time.Duration) (net.Conn, error) {
|
||||
agentSock := resolved.Endpoint
|
||||
if timeout > 0 {
|
||||
return net.DialTimeout("unix", agentSock, timeout)
|
||||
}
|
||||
return net.Dial("unix", agentSock)
|
||||
}
|
||||
@@ -0,0 +1,271 @@
|
||||
//go:build windows
|
||||
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Microsoft/go-winio"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
const defaultWindowsSSHAgentPipe = `\\.\pipe\openssh-ssh-agent`
|
||||
|
||||
var errInvalidGPGSocketInfo = errors.New("invalid gpg agent socket file")
|
||||
|
||||
type gpgSocketInfo struct {
|
||||
port uint16
|
||||
nonce []byte
|
||||
cygwin bool
|
||||
}
|
||||
|
||||
func defaultSSHAgentEndpoint() (resolvedSSHAgentEndpoint, error) {
|
||||
return resolvedSSHAgentEndpoint{
|
||||
Endpoint: defaultWindowsSSHAgentPipe,
|
||||
Source: "platform-default",
|
||||
Network: "windows-pipe",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func defaultSSHAgentNetwork(endpoint string) string {
|
||||
if _, ok := normalizeWindowsSSHAgentPipe(endpoint); ok {
|
||||
return "windows-pipe"
|
||||
}
|
||||
if isAgentSSHSocketPath(endpoint) {
|
||||
return "gpg-socket"
|
||||
}
|
||||
return "unix"
|
||||
}
|
||||
|
||||
func dialResolvedSSHAgent(resolved resolvedSSHAgentEndpoint, timeout time.Duration) (net.Conn, error) {
|
||||
if pipePath, ok := normalizeWindowsSSHAgentPipe(resolved.Endpoint); ok {
|
||||
return dialWindowsNamedPipe(pipePath, timeout, resolved.Source == "platform-default")
|
||||
}
|
||||
if isAgentSSHSocketPath(resolved.Endpoint) {
|
||||
return dialWindowsGPGSocketFile(resolved.Endpoint, timeout)
|
||||
}
|
||||
return dialWindowsUnixAgent(resolved.Endpoint, timeout)
|
||||
}
|
||||
|
||||
func dialWindowsNamedPipe(path string, timeout time.Duration, unavailableOnNotFound bool) (net.Conn, error) {
|
||||
ctx := context.Background()
|
||||
cancel := func() {}
|
||||
if timeout > 0 {
|
||||
ctx, cancel = context.WithTimeout(ctx, timeout)
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
return dialWindowsNamedPipeContext(ctx, path, unavailableOnNotFound)
|
||||
}
|
||||
|
||||
func normalizeWindowsSSHAgentPipe(endpoint string) (string, bool) {
|
||||
trimmed := strings.TrimSpace(endpoint)
|
||||
if trimmed == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
normalized := trimmed
|
||||
if strings.HasPrefix(normalized, "//./pipe/") {
|
||||
normalized = `\\.\pipe\` + strings.TrimPrefix(normalized, "//./pipe/")
|
||||
}
|
||||
if strings.HasPrefix(normalized, `\\.\pipe\`) {
|
||||
return normalized, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func isWindowsPipeUnavailable(err error) bool {
|
||||
return errors.Is(err, windows.ERROR_FILE_NOT_FOUND) || errors.Is(err, windows.ERROR_PATH_NOT_FOUND)
|
||||
}
|
||||
|
||||
func dialWindowsUnixAgent(endpoint string, timeout time.Duration) (net.Conn, error) {
|
||||
if timeout > 0 {
|
||||
return net.DialTimeout("unix", endpoint, timeout)
|
||||
}
|
||||
return net.Dial("unix", endpoint)
|
||||
}
|
||||
|
||||
func dialWindowsGPGSocketFile(path string, timeout time.Duration) (net.Conn, error) {
|
||||
ctx := context.Background()
|
||||
cancel := func() {}
|
||||
if timeout > 0 {
|
||||
ctx, cancel = context.WithTimeout(ctx, timeout)
|
||||
}
|
||||
defer cancel()
|
||||
return dialWindowsGPGSocketFileDepth(ctx, strings.TrimSpace(path), 0)
|
||||
}
|
||||
|
||||
func dialWindowsGPGSocketFileDepth(ctx context.Context, path string, depth int) (net.Conn, error) {
|
||||
if path == "" {
|
||||
return nil, fmt.Errorf("gpg agent endpoint is empty")
|
||||
}
|
||||
if depth > 8 {
|
||||
return nil, fmt.Errorf("gpg agent socket redirect loop at %s", path)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if target, ok := parseGPGAssuanSocketRedirect(data); ok {
|
||||
target = resolveGPGSocketRedirectTarget(path, target)
|
||||
if pipePath, ok := normalizeWindowsSSHAgentPipe(target); ok {
|
||||
return dialWindowsNamedPipeContext(ctx, pipePath, false)
|
||||
}
|
||||
return dialWindowsGPGSocketFileDepth(ctx, target, depth+1)
|
||||
}
|
||||
|
||||
info, err := parseGPGSocketInfo(path, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dialWindowsGPGSocketInfo(ctx, info)
|
||||
}
|
||||
|
||||
func dialWindowsGPGSocketInfo(ctx context.Context, info gpgSocketInfo) (net.Conn, error) {
|
||||
var dialer net.Dialer
|
||||
conn, err := dialer.DialContext(ctx, "tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(int(info.port))))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
if err := conn.SetDeadline(deadline); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if _, err := conn.Write(info.nonce); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
if info.cygwin {
|
||||
var nonce [16]byte
|
||||
if _, err := io.ReadFull(conn, nonce[:]); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
var credential [8]byte
|
||||
binary.LittleEndian.PutUint32(credential[:4], uint32(os.Getpid()))
|
||||
if _, err := conn.Write(credential[:]); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
if _, err := io.ReadFull(conn, credential[:]); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
_ = conn.SetDeadline(time.Time{})
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func resolveGPGSocketRedirectTarget(source string, target string) string {
|
||||
target = strings.TrimSpace(target)
|
||||
if target == "" || filepath.IsAbs(target) {
|
||||
return target
|
||||
}
|
||||
if _, ok := normalizeWindowsSSHAgentPipe(target); ok {
|
||||
return target
|
||||
}
|
||||
return filepath.Join(filepath.Dir(source), target)
|
||||
}
|
||||
|
||||
func parseGPGSocketInfo(path string, data []byte) (gpgSocketInfo, error) {
|
||||
if info, ok := parseGPGAssuanSocketInfo(data); ok {
|
||||
return info, nil
|
||||
}
|
||||
if info, ok := parseGPGCygwinSocketInfo(data); ok {
|
||||
return info, nil
|
||||
}
|
||||
return gpgSocketInfo{}, fmt.Errorf("%w %s: expected GnuPG port/nonce socket file; if SSH_AUTH_SOCK was set to this file, restart gpg-agent to recreate it", errInvalidGPGSocketInfo, path)
|
||||
}
|
||||
|
||||
func parseGPGAssuanSocketRedirect(data []byte) (string, bool) {
|
||||
text := strings.ReplaceAll(string(data), "\r\n", "\n")
|
||||
text = strings.TrimSuffix(text, "\n")
|
||||
lines := strings.Split(text, "\n")
|
||||
if len(lines) != 2 || lines[0] != "%Assuan%" {
|
||||
return "", false
|
||||
}
|
||||
target, ok := strings.CutPrefix(lines[1], "socket=")
|
||||
if !ok || strings.TrimSpace(target) == "" {
|
||||
return "", false
|
||||
}
|
||||
return os.ExpandEnv(target), true
|
||||
}
|
||||
|
||||
func parseGPGAssuanSocketInfo(data []byte) (gpgSocketInfo, bool) {
|
||||
newline := bytes.IndexByte(data, '\n')
|
||||
if newline <= 0 || len(data)-newline-1 != 16 {
|
||||
return gpgSocketInfo{}, false
|
||||
}
|
||||
port64, err := strconv.ParseUint(strings.TrimSpace(string(data[:newline])), 10, 16)
|
||||
if err != nil || port64 == 0 {
|
||||
return gpgSocketInfo{}, false
|
||||
}
|
||||
nonce := make([]byte, 16)
|
||||
copy(nonce, data[newline+1:])
|
||||
return gpgSocketInfo{port: uint16(port64), nonce: nonce}, true
|
||||
}
|
||||
|
||||
func parseGPGCygwinSocketInfo(data []byte) (gpgSocketInfo, bool) {
|
||||
if !bytes.HasPrefix(data, []byte("!<socket >")) {
|
||||
return gpgSocketInfo{}, false
|
||||
}
|
||||
fields := strings.Fields(strings.TrimRight(string(data[10:]), "\x00"))
|
||||
if len(fields) != 3 || fields[1] != "s" {
|
||||
return gpgSocketInfo{}, false
|
||||
}
|
||||
port64, err := strconv.ParseUint(fields[0], 10, 16)
|
||||
if err != nil || port64 == 0 {
|
||||
return gpgSocketInfo{}, false
|
||||
}
|
||||
hexParts := strings.Split(fields[2], "-")
|
||||
if len(hexParts) != 4 {
|
||||
return gpgSocketInfo{}, false
|
||||
}
|
||||
nonce := make([]byte, 0, 16)
|
||||
for _, part := range hexParts {
|
||||
if len(part) != 8 {
|
||||
return gpgSocketInfo{}, false
|
||||
}
|
||||
value, err := strconv.ParseUint(part, 16, 32)
|
||||
if err != nil {
|
||||
return gpgSocketInfo{}, false
|
||||
}
|
||||
var chunk [4]byte
|
||||
binary.LittleEndian.PutUint32(chunk[:], uint32(value))
|
||||
nonce = append(nonce, chunk[:]...)
|
||||
}
|
||||
return gpgSocketInfo{port: uint16(port64), nonce: nonce, cygwin: true}, true
|
||||
}
|
||||
|
||||
func isAgentSSHSocketPath(endpoint string) bool {
|
||||
normalized := strings.ToLower(strings.TrimSpace(endpoint))
|
||||
return strings.HasSuffix(normalized, "s.gpg-agent.ssh")
|
||||
}
|
||||
|
||||
func dialWindowsNamedPipeContext(ctx context.Context, path string, unavailableOnNotFound bool) (net.Conn, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
conn, err := winio.DialPipeContext(ctx, path)
|
||||
if err != nil && unavailableOnNotFound && isWindowsPipeUnavailable(err) {
|
||||
return nil, errSSHAgentUnavailable
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
@@ -0,0 +1,152 @@
|
||||
//go:build windows
|
||||
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestParseGPGAssuanSocketInfo(t *testing.T) {
|
||||
info, ok := parseGPGAssuanSocketInfo([]byte("7247\n0123456789abcdef"))
|
||||
if !ok {
|
||||
t.Fatal("expected Assuan socket info to parse")
|
||||
}
|
||||
if info.port != 7247 || string(info.nonce) != "0123456789abcdef" || info.cygwin {
|
||||
t.Fatalf("info=%+v nonce=%x", info, info.nonce)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseGPGCygwinSocketInfo(t *testing.T) {
|
||||
info, ok := parseGPGCygwinSocketInfo([]byte("!<socket >7247 s 00000001-02030405-06070809-0a0b0c0d\x00"))
|
||||
if !ok {
|
||||
t.Fatal("expected Cygwin socket info to parse")
|
||||
}
|
||||
want := []byte{1, 0, 0, 0, 5, 4, 3, 2, 9, 8, 7, 6, 13, 12, 11, 10}
|
||||
if info.port != 7247 || string(info.nonce) != string(want) || !info.cygwin {
|
||||
t.Fatalf("info=%+v nonce=%x", info, info.nonce)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseGPGAssuanSocketRedirect(t *testing.T) {
|
||||
t.Setenv("STARSSH_TEST_PIPE", `\\.\pipe\openssh-ssh-agent`)
|
||||
target, ok := parseGPGAssuanSocketRedirect([]byte("%Assuan%\r\nsocket=${STARSSH_TEST_PIPE}\r\n"))
|
||||
if !ok {
|
||||
t.Fatal("expected Assuan redirect to parse")
|
||||
}
|
||||
if target != `\\.\pipe\openssh-ssh-agent` {
|
||||
t.Fatalf("target=%q", target)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadInvalidAgentSSHSocketReturnsGPGSocketError(t *testing.T) {
|
||||
path := t.TempDir() + "/S.gpg-agent.ssh"
|
||||
if err := os.WriteFile(path, []byte("not a socket info file"), 0o600); err != nil {
|
||||
t.Fatalf("write socket file: %v", err)
|
||||
}
|
||||
|
||||
_, err := dialResolvedSSHAgent(resolvedSSHAgentEndpoint{
|
||||
Endpoint: path,
|
||||
Source: "SSH_AUTH_SOCK",
|
||||
Network: defaultSSHAgentNetwork(path),
|
||||
}, 0)
|
||||
if !errors.Is(err, errInvalidGPGSocketInfo) {
|
||||
t.Fatalf("err=%v want errInvalidGPGSocketInfo", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMissingAgentSSHSocketReturnsReadError(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "S.gpg-agent.ssh")
|
||||
|
||||
_, err := dialResolvedSSHAgent(resolvedSSHAgentEndpoint{
|
||||
Endpoint: path,
|
||||
Source: "identity-agent",
|
||||
Network: defaultSSHAgentNetwork(path),
|
||||
}, 0)
|
||||
if err == nil {
|
||||
t.Fatal("expected missing GPG socket file error")
|
||||
}
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
t.Fatalf("err=%v want os.ErrNotExist", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnreadableAgentSSHSocketReturnsReadError(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "S.gpg-agent.ssh")
|
||||
if err := os.Mkdir(path, 0o700); err != nil {
|
||||
t.Fatalf("mkdir socket path: %v", err)
|
||||
}
|
||||
|
||||
_, err := dialResolvedSSHAgent(resolvedSSHAgentEndpoint{
|
||||
Endpoint: path,
|
||||
Source: "identity-agent",
|
||||
Network: defaultSSHAgentNetwork(path),
|
||||
}, 0)
|
||||
if err == nil {
|
||||
t.Fatal("expected unreadable GPG socket file error")
|
||||
}
|
||||
if errors.Is(err, errInvalidGPGSocketInfo) {
|
||||
t.Fatalf("err=%v should expose read failure before parse", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialWindowsGPGSocketFilePerformsNonceHandshake(t *testing.T) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen tcp: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
type handshakeResult struct {
|
||||
nonce []byte
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan handshakeResult, 1)
|
||||
go func() {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
resultCh <- handshakeResult{err: err}
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
nonce := make([]byte, 16)
|
||||
if _, err := io.ReadFull(conn, nonce); err != nil {
|
||||
resultCh <- handshakeResult{err: err}
|
||||
return
|
||||
}
|
||||
resultCh <- handshakeResult{nonce: append([]byte(nil), nonce...)}
|
||||
}()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "S.gpg-agent.ssh")
|
||||
if err := os.WriteFile(socketPath, []byte(strconv.Itoa(listener.Addr().(*net.TCPAddr).Port)+"\n0123456789abcdef"), 0o600); err != nil {
|
||||
t.Fatalf("write socket file: %v", err)
|
||||
}
|
||||
|
||||
conn, err := dialWindowsGPGSocketFile(socketPath, time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("dialWindowsGPGSocketFile: %v", err)
|
||||
}
|
||||
_ = conn.Close()
|
||||
|
||||
var result handshakeResult
|
||||
select {
|
||||
case result = <-resultCh:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("listener did not accept GPG socket connection")
|
||||
}
|
||||
if result.err != nil {
|
||||
t.Fatalf("listener handshake error: %v", result.err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(result.nonce, []byte("0123456789abcdef")) {
|
||||
t.Fatalf("nonce=%q", result.nonce)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,135 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var errSSHClientClosing = errors.New("ssh client is closing")
|
||||
|
||||
type sshClientRequester interface {
|
||||
SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
var closeSSHClient = func(client sshClientRequester) error {
|
||||
if client == nil {
|
||||
return nil
|
||||
}
|
||||
return client.Close()
|
||||
}
|
||||
|
||||
func (s *StarSSH) snapshotSSHClient() *ssh.Client {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.stateMu.RLock()
|
||||
defer s.stateMu.RUnlock()
|
||||
return s.Client
|
||||
}
|
||||
|
||||
func (s *StarSSH) requireSSHClient() (*ssh.Client, error) {
|
||||
if s == nil {
|
||||
return nil, errors.New("ssh client is nil")
|
||||
}
|
||||
if s.closing.Load() {
|
||||
return nil, errSSHClientClosing
|
||||
}
|
||||
client := s.snapshotSSHClient()
|
||||
if client == nil {
|
||||
return nil, errors.New("ssh client is nil")
|
||||
}
|
||||
if s.closing.Load() {
|
||||
return nil, errSSHClientClosing
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (s *StarSSH) setTransport(client *ssh.Client, upstream *StarSSH) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.stateMu.Lock()
|
||||
defer s.stateMu.Unlock()
|
||||
s.Client = client
|
||||
s.upstream = upstream
|
||||
s.online = client != nil
|
||||
s.closing.Store(false)
|
||||
}
|
||||
|
||||
func (s *StarSSH) detachTransport() (*ssh.Client, *StarSSH) {
|
||||
if s == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
s.stateMu.Lock()
|
||||
defer s.stateMu.Unlock()
|
||||
|
||||
client := s.Client
|
||||
upstream := s.upstream
|
||||
s.Client = nil
|
||||
s.upstream = nil
|
||||
s.online = false
|
||||
return client, upstream
|
||||
}
|
||||
|
||||
func (s *StarSSH) takeKeepaliveHandles() (chan struct{}, chan struct{}) {
|
||||
if s == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
s.keepaliveMu.Lock()
|
||||
defer s.keepaliveMu.Unlock()
|
||||
|
||||
stop := s.keepaliveStop
|
||||
done := s.keepaliveDone
|
||||
s.keepaliveStop = nil
|
||||
s.keepaliveDone = nil
|
||||
return stop, done
|
||||
}
|
||||
|
||||
func (s *StarSSH) closeTransport(waitKeepalive bool) error {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.closing.Store(true)
|
||||
_ = s.closeReusableSFTPClient()
|
||||
agentForwarder := s.takeAgentForwarder()
|
||||
|
||||
client, upstream := s.detachTransport()
|
||||
stop, done := s.takeKeepaliveHandles()
|
||||
if stop != nil {
|
||||
close(stop)
|
||||
}
|
||||
|
||||
var closeErr error
|
||||
if agentForwarder != nil {
|
||||
closeErr = normalizeAlreadyClosedError(agentForwarder.Close())
|
||||
}
|
||||
if client != nil {
|
||||
if err := normalizeAlreadyClosedError(closeSSHClient(client)); closeErr == nil {
|
||||
closeErr = err
|
||||
}
|
||||
}
|
||||
if waitKeepalive && done != nil {
|
||||
<-done
|
||||
}
|
||||
if upstreamErr := closeUpstream(upstream); closeErr == nil {
|
||||
closeErr = upstreamErr
|
||||
}
|
||||
return closeErr
|
||||
}
|
||||
|
||||
func (s *StarSSH) canAttachAgentForwarder(client *ssh.Client) bool {
|
||||
if s == nil || client == nil || s.closing.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
s.stateMu.RLock()
|
||||
defer s.stateMu.RUnlock()
|
||||
return !s.closing.Load() && s.Client == client
|
||||
}
|
||||
+431
@@ -0,0 +1,431 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func (s *StarSSH) NewTerminal(config *TerminalConfig) (*TerminalSession, error) {
|
||||
session, err := s.NewPTYSession(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stdin, err := session.StdinPipe()
|
||||
if err != nil {
|
||||
_ = session.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stdout, err := session.StdoutPipe()
|
||||
if err != nil {
|
||||
_ = session.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stderr, err := session.StderrPipe()
|
||||
if err != nil {
|
||||
_ = session.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := session.Shell(); err != nil {
|
||||
_ = session.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TerminalSession{
|
||||
Session: session,
|
||||
stdin: stdin,
|
||||
stdout: stdout,
|
||||
stderr: stderr,
|
||||
runDone: make(chan struct{}),
|
||||
waitDone: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *TerminalSession) AttachIO(stdin io.Reader, stdout io.Writer, stderr io.Writer) {
|
||||
if t == nil {
|
||||
return
|
||||
}
|
||||
|
||||
t.attachMu.Lock()
|
||||
defer t.attachMu.Unlock()
|
||||
t.in = stdin
|
||||
t.out = stdout
|
||||
t.errOut = stderr
|
||||
}
|
||||
|
||||
func (t *TerminalSession) StdinWriter() io.Writer {
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
return t.stdin
|
||||
}
|
||||
|
||||
func (t *TerminalSession) StdoutReader() io.Reader {
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
return t.stdout
|
||||
}
|
||||
|
||||
func (t *TerminalSession) StderrReader() io.Reader {
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
return t.stderr
|
||||
}
|
||||
|
||||
func (t *TerminalSession) Write(data []byte) (int, error) {
|
||||
if t == nil || t.stdin == nil {
|
||||
return 0, errors.New("terminal stdin is not initialized")
|
||||
}
|
||||
return t.stdin.Write(data)
|
||||
}
|
||||
|
||||
func (t *TerminalSession) SendControl(control TerminalControl) error {
|
||||
_, err := t.Write([]byte{byte(control)})
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *TerminalSession) Interrupt() error {
|
||||
return t.SendControl(TerminalControlInterrupt)
|
||||
}
|
||||
|
||||
func (t *TerminalSession) Signal(sig ssh.Signal) error {
|
||||
if t == nil || t.Session == nil {
|
||||
return errors.New("terminal session is not initialized")
|
||||
}
|
||||
if sig == "" {
|
||||
return errors.New("signal is empty")
|
||||
}
|
||||
return t.Session.Signal(sig)
|
||||
}
|
||||
|
||||
func (t *TerminalSession) Run(ctx context.Context) error {
|
||||
if t == nil {
|
||||
return errors.New("terminal session is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
t.runOnce.Do(func() {
|
||||
t.runErr = t.run(ctx)
|
||||
close(t.runDone)
|
||||
})
|
||||
|
||||
<-t.runDone
|
||||
return t.runErr
|
||||
}
|
||||
|
||||
func (t *TerminalSession) run(ctx context.Context) error {
|
||||
if t.Session == nil {
|
||||
return errors.New("terminal session is not initialized")
|
||||
}
|
||||
|
||||
t.attachMu.RLock()
|
||||
in := t.in
|
||||
out := t.out
|
||||
errOut := t.errOut
|
||||
t.attachMu.RUnlock()
|
||||
if out == nil {
|
||||
out = io.Discard
|
||||
}
|
||||
if errOut == nil {
|
||||
errOut = out
|
||||
}
|
||||
|
||||
inputReader, cancelInput, inputCancelable, err := prepareTerminalInputReader(in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancelInput()
|
||||
|
||||
var copyWG sync.WaitGroup
|
||||
doneCopy := make(chan struct{})
|
||||
copyWG.Add(2)
|
||||
go func() {
|
||||
defer copyWG.Done()
|
||||
if t.stdout != nil {
|
||||
_, _ = io.Copy(out, t.stdout)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer copyWG.Done()
|
||||
if t.stderr != nil {
|
||||
_, _ = io.Copy(errOut, t.stderr)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
copyWG.Wait()
|
||||
close(doneCopy)
|
||||
}()
|
||||
|
||||
var doneInput chan struct{}
|
||||
if inputReader != nil && t.stdin != nil {
|
||||
doneInput = make(chan struct{})
|
||||
go func() {
|
||||
defer close(doneInput)
|
||||
_, _ = io.Copy(t.stdin, inputReader)
|
||||
_ = t.stdin.Close()
|
||||
}()
|
||||
}
|
||||
|
||||
waitInputPump := func() {
|
||||
if doneInput == nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-doneInput:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if inputCancelable {
|
||||
cancelInput()
|
||||
<-doneInput
|
||||
}
|
||||
}
|
||||
|
||||
type waitResult struct {
|
||||
info TerminalExitInfo
|
||||
err error
|
||||
}
|
||||
|
||||
waitCh := make(chan waitResult, 1)
|
||||
go func() {
|
||||
info, err := t.WaitResult()
|
||||
waitCh <- waitResult{info: info, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case result := <-waitCh:
|
||||
waitInputPump()
|
||||
<-doneCopy
|
||||
if result.err != nil {
|
||||
return result.err
|
||||
}
|
||||
return result.info.CommandError()
|
||||
case <-ctx.Done():
|
||||
t.markCloseReason(terminalCloseReasonFromErr(ctx.Err()), ctx.Err())
|
||||
cancelInput()
|
||||
_ = t.Close()
|
||||
<-waitCh
|
||||
waitInputPump()
|
||||
<-doneCopy
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TerminalSession) Wait() error {
|
||||
info, err := t.WaitResult()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return info.CommandError()
|
||||
}
|
||||
|
||||
func (t *TerminalSession) WaitResult() (TerminalExitInfo, error) {
|
||||
waitErr := t.waitRaw()
|
||||
info, closeErr := t.snapshotExitState()
|
||||
if closeErr != nil {
|
||||
return info, closeErr
|
||||
}
|
||||
if waitErr == nil {
|
||||
return info, nil
|
||||
}
|
||||
|
||||
var exitErr *ssh.ExitError
|
||||
if errors.As(waitErr, &exitErr) {
|
||||
return info, nil
|
||||
}
|
||||
if normalizeAlreadyClosedError(waitErr) == nil || info.Reason == TerminalCloseReasonClosed {
|
||||
return info, nil
|
||||
}
|
||||
return info, waitErr
|
||||
}
|
||||
|
||||
func (t *TerminalSession) ExitInfo() TerminalExitInfo {
|
||||
if t == nil {
|
||||
return TerminalExitInfo{}
|
||||
}
|
||||
|
||||
t.stateMu.RLock()
|
||||
defer t.stateMu.RUnlock()
|
||||
return t.exitInfo
|
||||
}
|
||||
|
||||
func (info TerminalExitInfo) Success() bool {
|
||||
return info.Reason == TerminalCloseReasonExit && info.ExitCode == 0 && info.ExitSignal == ""
|
||||
}
|
||||
|
||||
func (info TerminalExitInfo) CommandError() error {
|
||||
if info.Reason != TerminalCloseReasonExit && info.Reason != TerminalCloseReasonSignal {
|
||||
return nil
|
||||
}
|
||||
if info.ExitCode == 0 && info.ExitSignal == "" {
|
||||
return nil
|
||||
}
|
||||
return &ExecExitError{
|
||||
Status: info.ExitCode,
|
||||
Signal: info.ExitSignal,
|
||||
Message: info.ExitMessage,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TerminalSession) Resize(columns int, rows int) error {
|
||||
if t == nil || t.Session == nil {
|
||||
return errors.New("terminal session is not initialized")
|
||||
}
|
||||
if columns <= 0 || rows <= 0 {
|
||||
return errors.New("columns and rows must be > 0")
|
||||
}
|
||||
return t.Session.WindowChange(rows, columns)
|
||||
}
|
||||
|
||||
func (t *TerminalSession) Close() error {
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var closeErr error
|
||||
t.closeOnce.Do(func() {
|
||||
if t.stdin != nil {
|
||||
_ = t.stdin.Close()
|
||||
}
|
||||
if t.Session != nil {
|
||||
closeErr = normalizeAlreadyClosedError(t.Session.Close())
|
||||
}
|
||||
})
|
||||
if closeErr != nil {
|
||||
t.markCloseReason(TerminalCloseReasonTransportError, closeErr)
|
||||
}
|
||||
return closeErr
|
||||
}
|
||||
|
||||
func (t *TerminalSession) waitRaw() error {
|
||||
if t == nil || t.Session == nil {
|
||||
return errors.New("terminal session is not initialized")
|
||||
}
|
||||
|
||||
t.waitOnce.Do(func() {
|
||||
go func() {
|
||||
waitErr := t.Session.Wait()
|
||||
t.setWaitResult(waitErr)
|
||||
close(t.waitDone)
|
||||
}()
|
||||
})
|
||||
|
||||
<-t.waitDone
|
||||
|
||||
t.stateMu.RLock()
|
||||
defer t.stateMu.RUnlock()
|
||||
return t.waitErr
|
||||
}
|
||||
|
||||
func (t *TerminalSession) setWaitResult(waitErr error) {
|
||||
if t == nil {
|
||||
return
|
||||
}
|
||||
|
||||
t.stateMu.Lock()
|
||||
defer t.stateMu.Unlock()
|
||||
t.waitErr = waitErr
|
||||
t.exitInfo = buildTerminalExitInfo(waitErr, t.closeReason)
|
||||
}
|
||||
|
||||
func (t *TerminalSession) markCloseReason(reason TerminalCloseReason, err error) {
|
||||
if t == nil || reason == TerminalCloseReasonUnknown {
|
||||
return
|
||||
}
|
||||
|
||||
t.stateMu.Lock()
|
||||
defer t.stateMu.Unlock()
|
||||
|
||||
if terminalCloseReasonPriority(reason) >= terminalCloseReasonPriority(t.closeReason) {
|
||||
t.closeReason = reason
|
||||
}
|
||||
if err != nil && t.closeErr == nil {
|
||||
t.closeErr = err
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TerminalSession) snapshotExitState() (TerminalExitInfo, error) {
|
||||
if t == nil {
|
||||
return TerminalExitInfo{}, nil
|
||||
}
|
||||
|
||||
t.stateMu.RLock()
|
||||
defer t.stateMu.RUnlock()
|
||||
return t.exitInfo, t.closeErr
|
||||
}
|
||||
|
||||
func buildTerminalExitInfo(waitErr error, overrideReason TerminalCloseReason) TerminalExitInfo {
|
||||
info := TerminalExitInfo{}
|
||||
|
||||
if waitErr == nil {
|
||||
info.Reason = TerminalCloseReasonExit
|
||||
} else {
|
||||
var exitErr *ssh.ExitError
|
||||
switch {
|
||||
case errors.As(waitErr, &exitErr):
|
||||
info.ExitCode = exitErr.ExitStatus()
|
||||
info.ExitSignal = exitErr.Signal()
|
||||
info.ExitMessage = exitErr.Msg()
|
||||
if info.ExitSignal != "" {
|
||||
info.Reason = TerminalCloseReasonSignal
|
||||
} else {
|
||||
info.Reason = TerminalCloseReasonExit
|
||||
}
|
||||
case normalizeAlreadyClosedError(waitErr) == nil:
|
||||
info.Reason = TerminalCloseReasonClosed
|
||||
case errors.Is(waitErr, context.Canceled):
|
||||
info.Reason = TerminalCloseReasonContextCanceled
|
||||
case errors.Is(waitErr, context.DeadlineExceeded):
|
||||
info.Reason = TerminalCloseReasonDeadlineExceeded
|
||||
default:
|
||||
info.Reason = TerminalCloseReasonTransportError
|
||||
}
|
||||
}
|
||||
|
||||
if overrideReason != TerminalCloseReasonUnknown &&
|
||||
terminalCloseReasonPriority(overrideReason) >= terminalCloseReasonPriority(info.Reason) {
|
||||
info.Reason = overrideReason
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
func terminalCloseReasonFromErr(err error) TerminalCloseReason {
|
||||
switch {
|
||||
case errors.Is(err, context.DeadlineExceeded):
|
||||
return TerminalCloseReasonDeadlineExceeded
|
||||
case errors.Is(err, context.Canceled):
|
||||
return TerminalCloseReasonContextCanceled
|
||||
case err != nil:
|
||||
return TerminalCloseReasonTransportError
|
||||
default:
|
||||
return TerminalCloseReasonUnknown
|
||||
}
|
||||
}
|
||||
|
||||
func terminalCloseReasonPriority(reason TerminalCloseReason) int {
|
||||
switch reason {
|
||||
case TerminalCloseReasonContextCanceled, TerminalCloseReasonDeadlineExceeded:
|
||||
return 30
|
||||
case TerminalCloseReasonTransportError:
|
||||
return 20
|
||||
case TerminalCloseReasonClosed:
|
||||
return 10
|
||||
case TerminalCloseReasonSignal, TerminalCloseReasonExit:
|
||||
return 5
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,225 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"reflect"
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// TerminalInputSourceProvider lets wrapper readers expose a closer-friendly source reader.
|
||||
// Implementations that buffer data should return a source that already includes any prefetched bytes.
|
||||
type TerminalInputSourceProvider interface {
|
||||
TerminalInputSource() io.Reader
|
||||
}
|
||||
|
||||
// TerminalInputCanceler lets wrapper readers expose an explicit cancellation hook.
|
||||
// It is useful for line editors or custom buffered readers that cannot safely expose a raw io.ReadCloser.
|
||||
type TerminalInputCanceler interface {
|
||||
TerminalInputCancel() error
|
||||
}
|
||||
|
||||
// TerminalInputAdapter adapts wrapper readers into a cancelable terminal input source.
|
||||
// Reader is what TerminalSession consumes, Source is the closer-friendly underlying reader when available.
|
||||
type TerminalInputAdapter struct {
|
||||
Reader io.Reader
|
||||
Source io.Reader
|
||||
Cancel func() error
|
||||
}
|
||||
|
||||
func (a TerminalInputAdapter) Read(p []byte) (int, error) {
|
||||
if a.Reader == nil {
|
||||
return 0, io.EOF
|
||||
}
|
||||
return a.Reader.Read(p)
|
||||
}
|
||||
|
||||
func (a TerminalInputAdapter) TerminalInputSource() io.Reader {
|
||||
if a.Source != nil {
|
||||
return a.Source
|
||||
}
|
||||
return a.Reader
|
||||
}
|
||||
|
||||
func (a TerminalInputAdapter) TerminalInputCancel() error {
|
||||
if a.Cancel != nil {
|
||||
return a.Cancel()
|
||||
}
|
||||
if closer, ok := a.Source.(io.Closer); ok && closer != nil {
|
||||
return closer.Close()
|
||||
}
|
||||
if closer, ok := a.Reader.(io.Closer); ok && closer != nil {
|
||||
return closer.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func prepareTerminalInputReader(in io.Reader) (io.Reader, func(), bool, error) {
|
||||
if in == nil {
|
||||
return nil, func() {}, false, nil
|
||||
}
|
||||
|
||||
var cancelOnce sync.Once
|
||||
wrapCancel := func(fn func()) func() {
|
||||
return func() {
|
||||
cancelOnce.Do(fn)
|
||||
}
|
||||
}
|
||||
|
||||
if provider, ok := in.(TerminalInputSourceProvider); ok {
|
||||
source := provider.TerminalInputSource()
|
||||
if source == nil || sameReader(source, in) {
|
||||
return prepareDirectTerminalInputReader(in, wrapCancel)
|
||||
}
|
||||
|
||||
prepared, cancel, cancelable, err := prepareTerminalInputReader(source)
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
if canceler, ok := in.(TerminalInputCanceler); ok {
|
||||
return prepared, wrapCancel(func() {
|
||||
cancel()
|
||||
_ = canceler.TerminalInputCancel()
|
||||
}), true, nil
|
||||
}
|
||||
return prepared, cancel, cancelable, nil
|
||||
}
|
||||
|
||||
return prepareDirectTerminalInputReader(in, wrapCancel)
|
||||
}
|
||||
|
||||
func prepareDirectTerminalInputReader(in io.Reader, wrapCancel func(func()) func()) (io.Reader, func(), bool, error) {
|
||||
if in == nil {
|
||||
return nil, func() {}, false, nil
|
||||
}
|
||||
|
||||
switch typed := in.(type) {
|
||||
case *bufio.Reader:
|
||||
return prepareBufferedTerminalInputReader(typed)
|
||||
case *bufio.ReadWriter:
|
||||
if typed.Reader == nil {
|
||||
return in, func() {}, false, nil
|
||||
}
|
||||
return prepareBufferedTerminalInputReader(typed.Reader)
|
||||
}
|
||||
|
||||
if canceler, ok := in.(TerminalInputCanceler); ok {
|
||||
return in, wrapCancel(func() {
|
||||
_ = canceler.TerminalInputCancel()
|
||||
}), true, nil
|
||||
}
|
||||
|
||||
if file, ok := in.(*os.File); ok {
|
||||
dup, err := duplicateTerminalInputFile(file)
|
||||
if err != nil {
|
||||
return nil, nil, false, fmt.Errorf("duplicate terminal input: %w", err)
|
||||
}
|
||||
return dup, wrapCancel(func() {
|
||||
_ = dup.Close()
|
||||
}), true, nil
|
||||
}
|
||||
|
||||
if closer, ok := in.(io.ReadCloser); ok {
|
||||
return closer, wrapCancel(func() {
|
||||
_ = closer.Close()
|
||||
}), true, nil
|
||||
}
|
||||
|
||||
return in, func() {}, false, nil
|
||||
}
|
||||
|
||||
func prepareBufferedTerminalInputReader(reader *bufio.Reader) (io.Reader, func(), bool, error) {
|
||||
if reader == nil {
|
||||
return nil, func() {}, false, nil
|
||||
}
|
||||
|
||||
bufferedPrefix, err := snapshotBufferedPrefix(reader)
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
|
||||
underlying := unwrapBufioReader(reader)
|
||||
if underlying == nil {
|
||||
if len(bufferedPrefix) == 0 {
|
||||
return reader, func() {}, false, nil
|
||||
}
|
||||
return io.MultiReader(bytes.NewReader(bufferedPrefix), reader), func() {}, false, nil
|
||||
}
|
||||
|
||||
prepared, cancel, cancelable, err := prepareTerminalInputReader(underlying)
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
if len(bufferedPrefix) == 0 {
|
||||
return prepared, cancel, cancelable, nil
|
||||
}
|
||||
if prepared == nil {
|
||||
return bytes.NewReader(bufferedPrefix), cancel, cancelable, nil
|
||||
}
|
||||
return io.MultiReader(bytes.NewReader(bufferedPrefix), prepared), cancel, cancelable, nil
|
||||
}
|
||||
|
||||
func snapshotBufferedPrefix(reader *bufio.Reader) ([]byte, error) {
|
||||
if reader == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
buffered := reader.Buffered()
|
||||
if buffered == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
chunk, err := reader.Peek(buffered)
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
return nil, fmt.Errorf("peek terminal input buffer: %w", err)
|
||||
}
|
||||
prefix := append([]byte(nil), chunk...)
|
||||
if _, err := reader.Discard(len(prefix)); err != nil {
|
||||
return nil, fmt.Errorf("discard terminal input buffer: %w", err)
|
||||
}
|
||||
return prefix, nil
|
||||
}
|
||||
|
||||
func unwrapBufioReader(reader *bufio.Reader) io.Reader {
|
||||
if reader == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
value := reflect.ValueOf(reader)
|
||||
if value.Kind() != reflect.Pointer || value.IsNil() {
|
||||
return nil
|
||||
}
|
||||
|
||||
field := value.Elem().FieldByName("rd")
|
||||
if !field.IsValid() {
|
||||
return nil
|
||||
}
|
||||
|
||||
underlyingValue := reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem()
|
||||
underlying, ok := underlyingValue.Interface().(io.Reader)
|
||||
if !ok || underlying == nil || sameReader(underlying, reader) {
|
||||
return nil
|
||||
}
|
||||
return underlying
|
||||
}
|
||||
|
||||
func sameReader(left io.Reader, right io.Reader) bool {
|
||||
if left == nil || right == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
leftValue := reflect.ValueOf(left)
|
||||
rightValue := reflect.ValueOf(right)
|
||||
if !leftValue.IsValid() || !rightValue.IsValid() {
|
||||
return false
|
||||
}
|
||||
if leftValue.Kind() != reflect.Pointer || rightValue.Kind() != reflect.Pointer {
|
||||
return false
|
||||
}
|
||||
return leftValue.Pointer() == rightValue.Pointer()
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPrepareTerminalInputReaderAdapterCancelUnblocksRead(t *testing.T) {
|
||||
reader, writer := io.Pipe()
|
||||
defer reader.Close()
|
||||
defer writer.Close()
|
||||
|
||||
adapter := TerminalInputAdapter{
|
||||
Reader: bufio.NewReader(reader),
|
||||
Source: reader,
|
||||
Cancel: func() error {
|
||||
return reader.CloseWithError(io.ErrClosedPipe)
|
||||
},
|
||||
}
|
||||
|
||||
prepared, cancel, cancelable, err := prepareTerminalInputReader(adapter)
|
||||
if err != nil {
|
||||
t.Fatalf("prepare adapter input: %v", err)
|
||||
}
|
||||
if !cancelable {
|
||||
t.Fatal("expected adapter-backed input to be cancelable")
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
buf := make([]byte, 1)
|
||||
_, readErr := prepared.Read(buf)
|
||||
done <- readErr
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case readErr := <-done:
|
||||
if readErr == nil {
|
||||
t.Fatal("expected adapter cancel to interrupt blocking read")
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("blocking adapter input did not unblock after cancel")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,290 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type terminalInputProvider struct {
|
||||
io.Reader
|
||||
source io.Reader
|
||||
}
|
||||
|
||||
func (p terminalInputProvider) TerminalInputSource() io.Reader {
|
||||
return p.source
|
||||
}
|
||||
|
||||
type prefixedReadCloser struct {
|
||||
io.Reader
|
||||
io.Closer
|
||||
}
|
||||
|
||||
func TestPrepareTerminalInputReaderBufioReaderPreservesBufferedBytes(t *testing.T) {
|
||||
reader, writer, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create pipe: %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
if _, err := writer.Write([]byte("hello world")); err != nil {
|
||||
writer.Close()
|
||||
t.Fatalf("write pipe: %v", err)
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
t.Fatalf("close writer: %v", err)
|
||||
}
|
||||
|
||||
buffered := bufio.NewReaderSize(reader, 4)
|
||||
peeked, err := buffered.Peek(5)
|
||||
if err != nil {
|
||||
t.Fatalf("prime buffer: %v", err)
|
||||
}
|
||||
if string(peeked) != "hello" {
|
||||
t.Fatalf("unexpected buffered prefix: %q", string(peeked))
|
||||
}
|
||||
|
||||
prepared, cancel, cancelable, err := prepareTerminalInputReader(buffered)
|
||||
if err != nil {
|
||||
t.Fatalf("prepare input: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
if !cancelable {
|
||||
t.Fatal("expected cancelable reader")
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(prepared)
|
||||
if err != nil {
|
||||
t.Fatalf("read prepared input: %v", err)
|
||||
}
|
||||
if string(data) != "hello world" {
|
||||
t.Fatalf("unexpected prepared input: %q", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareTerminalInputReaderBufioReadWriterPreservesBufferedBytes(t *testing.T) {
|
||||
reader, writer, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create pipe: %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
if _, err := writer.Write([]byte("buffered payload")); err != nil {
|
||||
writer.Close()
|
||||
t.Fatalf("write pipe: %v", err)
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
t.Fatalf("close writer: %v", err)
|
||||
}
|
||||
|
||||
readWriter := bufio.NewReadWriter(bufio.NewReaderSize(reader, 8), bufio.NewWriter(io.Discard))
|
||||
if _, err := readWriter.Reader.Peek(8); err != nil {
|
||||
t.Fatalf("prime readwriter buffer: %v", err)
|
||||
}
|
||||
|
||||
prepared, cancel, cancelable, err := prepareTerminalInputReader(readWriter)
|
||||
if err != nil {
|
||||
t.Fatalf("prepare readwriter input: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
if !cancelable {
|
||||
t.Fatal("expected cancelable readwriter input")
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(prepared)
|
||||
if err != nil {
|
||||
t.Fatalf("read prepared readwriter input: %v", err)
|
||||
}
|
||||
if string(data) != "buffered payload" {
|
||||
t.Fatalf("unexpected prepared readwriter input: %q", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareTerminalInputReaderBufioReaderFallbackKeepsData(t *testing.T) {
|
||||
buffered := bufio.NewReader(bytes.NewBufferString("abc123"))
|
||||
if _, err := buffered.Peek(3); err != nil {
|
||||
t.Fatalf("prime buffer: %v", err)
|
||||
}
|
||||
|
||||
prepared, cancel, cancelable, err := prepareTerminalInputReader(buffered)
|
||||
if err != nil {
|
||||
t.Fatalf("prepare fallback input: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
if cancelable {
|
||||
t.Fatal("expected non-cancelable reader")
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(prepared)
|
||||
if err != nil {
|
||||
t.Fatalf("read fallback input: %v", err)
|
||||
}
|
||||
if string(data) != "abc123" {
|
||||
t.Fatalf("unexpected fallback input: %q", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareTerminalInputReaderProviderPrefersExplicitSource(t *testing.T) {
|
||||
reader, writer, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create pipe: %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
if _, err := writer.Write([]byte("provider data")); err != nil {
|
||||
writer.Close()
|
||||
t.Fatalf("write pipe: %v", err)
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
t.Fatalf("close writer: %v", err)
|
||||
}
|
||||
|
||||
buffered := bufio.NewReader(reader)
|
||||
prefix, err := buffered.Peek(len("provider data"))
|
||||
if err != nil {
|
||||
t.Fatalf("prime buffer: %v", err)
|
||||
}
|
||||
|
||||
source := prefixedReadCloser{
|
||||
Reader: io.MultiReader(bytes.NewReader(append([]byte(nil), prefix...)), reader),
|
||||
Closer: reader,
|
||||
}
|
||||
provider := terminalInputProvider{
|
||||
Reader: buffered,
|
||||
source: source,
|
||||
}
|
||||
|
||||
prepared, cancel, cancelable, err := prepareTerminalInputReader(provider)
|
||||
if err != nil {
|
||||
t.Fatalf("prepare provider input: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
if !cancelable {
|
||||
t.Fatal("expected provider-backed input to be cancelable")
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(prepared)
|
||||
if err != nil {
|
||||
t.Fatalf("read provider input: %v", err)
|
||||
}
|
||||
if string(data) != "provider data" {
|
||||
t.Fatalf("unexpected provider input: %q", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareTerminalInputReaderProviderCancelUnblocksRead(t *testing.T) {
|
||||
reader, writer := io.Pipe()
|
||||
defer reader.Close()
|
||||
defer writer.Close()
|
||||
|
||||
buffered := bufio.NewReader(reader)
|
||||
provider := terminalInputProvider{
|
||||
Reader: buffered,
|
||||
source: reader,
|
||||
}
|
||||
|
||||
prepared, cancel, cancelable, err := prepareTerminalInputReader(provider)
|
||||
if err != nil {
|
||||
t.Fatalf("prepare input: %v", err)
|
||||
}
|
||||
|
||||
if !cancelable {
|
||||
t.Fatal("expected provider-backed input to be cancelable")
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
buf := make([]byte, 1)
|
||||
_, readErr := prepared.Read(buf)
|
||||
done <- readErr
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case readErr := <-done:
|
||||
if readErr == nil {
|
||||
t.Fatal("expected cancel to interrupt blocking read")
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("blocking read did not unblock after cancel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareTerminalInputReaderBufioReaderCancelUnblocksRead(t *testing.T) {
|
||||
reader, writer := io.Pipe()
|
||||
defer reader.Close()
|
||||
defer writer.Close()
|
||||
|
||||
buffered := bufio.NewReader(reader)
|
||||
prepared, cancel, cancelable, err := prepareTerminalInputReader(buffered)
|
||||
if err != nil {
|
||||
t.Fatalf("prepare input: %v", err)
|
||||
}
|
||||
|
||||
if !cancelable {
|
||||
t.Fatal("expected bufio reader to be cancelable")
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
buf := make([]byte, 1)
|
||||
_, readErr := prepared.Read(buf)
|
||||
done <- readErr
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case readErr := <-done:
|
||||
if readErr == nil {
|
||||
t.Fatal("expected cancel to interrupt blocking read")
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("blocking bufio reader did not unblock after cancel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareTerminalInputReaderBufioReadWriterCancelUnblocksRead(t *testing.T) {
|
||||
reader, writer := io.Pipe()
|
||||
defer reader.Close()
|
||||
defer writer.Close()
|
||||
|
||||
readWriter := bufio.NewReadWriter(bufio.NewReader(reader), bufio.NewWriter(io.Discard))
|
||||
prepared, cancel, cancelable, err := prepareTerminalInputReader(readWriter)
|
||||
if err != nil {
|
||||
t.Fatalf("prepare readwriter input: %v", err)
|
||||
}
|
||||
|
||||
if !cancelable {
|
||||
t.Fatal("expected bufio readwriter to be cancelable")
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
buf := make([]byte, 1)
|
||||
_, readErr := prepared.Read(buf)
|
||||
done <- readErr
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case readErr := <-done:
|
||||
if readErr == nil {
|
||||
t.Fatal("expected cancel to interrupt blocking read")
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("blocking readwriter input did not unblock after cancel")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
//go:build !windows
|
||||
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func duplicateTerminalInputFile(file *os.File) (*os.File, error) {
|
||||
if file == nil {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
|
||||
fd, err := syscall.Dup(int(file.Fd()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
syscall.CloseOnExec(fd)
|
||||
return os.NewFile(uintptr(fd), file.Name()), nil
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
//go:build windows
|
||||
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func duplicateTerminalInputFile(file *os.File) (*os.File, error) {
|
||||
if file == nil {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
|
||||
currentProcess := windows.CurrentProcess()
|
||||
var duplicated windows.Handle
|
||||
err := windows.DuplicateHandle(
|
||||
currentProcess,
|
||||
windows.Handle(file.Fd()),
|
||||
currentProcess,
|
||||
&duplicated,
|
||||
0,
|
||||
false,
|
||||
windows.DUPLICATE_SAME_ACCESS,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return os.NewFile(uintptr(duplicated), file.Name()), nil
|
||||
}
|
||||
+336
@@ -0,0 +1,336 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type bufferedConn struct {
|
||||
net.Conn
|
||||
reader *bufio.Reader
|
||||
}
|
||||
|
||||
func (c *bufferedConn) Read(p []byte) (int, error) {
|
||||
if c == nil || c.reader == nil {
|
||||
return 0, io.EOF
|
||||
}
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
func resolveDialContext(info LoginInput) DialContextFunc {
|
||||
if info.DialContext != nil {
|
||||
return info.DialContext
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: effectiveDialTimeout(info),
|
||||
}
|
||||
return dialer.DialContext
|
||||
}
|
||||
|
||||
func dialTargetConn(ctx context.Context, info LoginInput) (net.Conn, *StarSSH, error) {
|
||||
targetAddr := joinHostPort(info.Addr, info.Port)
|
||||
if info.Jump != nil {
|
||||
return dialViaJump(ctx, info, targetAddr)
|
||||
}
|
||||
|
||||
dialContext := resolveDialContext(info)
|
||||
proxyConfig := normalizeProxyConfig(info.Proxy, effectiveDialTimeout(info))
|
||||
if proxyConfig != nil {
|
||||
return dialViaProxy(ctx, dialContext, *proxyConfig, targetAddr)
|
||||
}
|
||||
|
||||
conn, err := dialContext(ctx, "tcp", targetAddr)
|
||||
return conn, nil, err
|
||||
}
|
||||
|
||||
func dialViaJump(ctx context.Context, info LoginInput, targetAddr string) (net.Conn, *StarSSH, error) {
|
||||
if info.Jump == nil {
|
||||
return nil, nil, errors.New("jump login info is nil")
|
||||
}
|
||||
|
||||
jumpClient, err := loginWithContext(ctx, *info.Jump)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
conn, err := jumpClient.dialTCPContext(ctx, "tcp", targetAddr, jumpClient.Close)
|
||||
if err != nil {
|
||||
_ = jumpClient.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
return conn, jumpClient, nil
|
||||
}
|
||||
|
||||
func dialViaProxy(ctx context.Context, dialContext DialContextFunc, proxy ProxyConfig, targetAddr string) (net.Conn, *StarSSH, error) {
|
||||
if dialContext == nil {
|
||||
return nil, nil, errors.New("dial context is nil")
|
||||
}
|
||||
if strings.TrimSpace(proxy.Addr) == "" {
|
||||
return nil, nil, errors.New("proxy address is empty")
|
||||
}
|
||||
|
||||
switch proxy.Type {
|
||||
case ProxyTypeSOCKS5:
|
||||
conn, err := dialSOCKS5(ctx, dialContext, proxy, targetAddr)
|
||||
return conn, nil, err
|
||||
case ProxyTypeHTTPConnect:
|
||||
conn, err := dialHTTPConnect(ctx, dialContext, proxy, targetAddr)
|
||||
return conn, nil, err
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("unsupported proxy type %q", proxy.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func dialHTTPConnect(ctx context.Context, dialContext DialContextFunc, proxy ProxyConfig, targetAddr string) (net.Conn, error) {
|
||||
conn, err := dialContext(ctx, "tcp", proxy.Addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
restoreDeadline := applyConnDeadline(conn, ctx, proxy.Timeout)
|
||||
defer restoreDeadline()
|
||||
|
||||
request := fmt.Sprintf("CONNECT %s HTTP/1.1\r\nHost: %s\r\n", targetAddr, targetAddr)
|
||||
if proxy.Username != "" || proxy.Password != "" {
|
||||
token := base64.StdEncoding.EncodeToString([]byte(proxy.Username + ":" + proxy.Password))
|
||||
request += "Proxy-Authorization: Basic " + token + "\r\n"
|
||||
}
|
||||
request += "\r\n"
|
||||
|
||||
if _, err := io.WriteString(conn, request); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
response, err := http.ReadResponse(reader, &http.Request{Method: http.MethodConnect})
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode < 200 || response.StatusCode >= 300 {
|
||||
_, _ = io.Copy(io.Discard, io.LimitReader(response.Body, 1024))
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("http CONNECT proxy rejected target %s: %s", targetAddr, response.Status)
|
||||
}
|
||||
|
||||
if reader.Buffered() == 0 {
|
||||
return conn, nil
|
||||
}
|
||||
return &bufferedConn{Conn: conn, reader: reader}, nil
|
||||
}
|
||||
|
||||
func dialSOCKS5(ctx context.Context, dialContext DialContextFunc, proxy ProxyConfig, targetAddr string) (net.Conn, error) {
|
||||
conn, err := dialContext(ctx, "tcp", proxy.Addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
restoreDeadline := applyConnDeadline(conn, ctx, proxy.Timeout)
|
||||
defer restoreDeadline()
|
||||
|
||||
methods := []byte{0x00}
|
||||
useAuth := proxy.Username != "" || proxy.Password != ""
|
||||
if useAuth {
|
||||
methods = append(methods, 0x02)
|
||||
}
|
||||
|
||||
hello := append([]byte{0x05, byte(len(methods))}, methods...)
|
||||
if _, err := conn.Write(hello); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, response); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
if response[0] != 0x05 {
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("invalid socks5 version %d", response[0])
|
||||
}
|
||||
if response[1] == 0xFF {
|
||||
_ = conn.Close()
|
||||
return nil, errors.New("socks5 proxy has no acceptable auth method")
|
||||
}
|
||||
|
||||
if response[1] == 0x02 {
|
||||
if err := writeSOCKS5UserPassAuth(conn, proxy.Username, proxy.Password); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := writeSOCKS5Connect(conn, targetAddr); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := readSOCKS5ConnectResponse(conn); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func writeSOCKS5UserPassAuth(conn net.Conn, username string, password string) error {
|
||||
if len(username) > 255 || len(password) > 255 {
|
||||
return errors.New("socks5 username/password too long")
|
||||
}
|
||||
|
||||
request := make([]byte, 0, 3+len(username)+len(password))
|
||||
request = append(request, 0x01, byte(len(username)))
|
||||
request = append(request, []byte(username)...)
|
||||
request = append(request, byte(len(password)))
|
||||
request = append(request, []byte(password)...)
|
||||
|
||||
if _, err := conn.Write(request); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, response); err != nil {
|
||||
return err
|
||||
}
|
||||
if response[1] != 0x00 {
|
||||
return errors.New("socks5 username/password authentication failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeSOCKS5Connect(conn net.Conn, targetAddr string) error {
|
||||
host, portString, err := net.SplitHostPort(targetAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if port < 0 || port > 65535 {
|
||||
return fmt.Errorf("invalid port %d", port)
|
||||
}
|
||||
|
||||
request := []byte{0x05, 0x01, 0x00}
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
request = append(request, 0x01)
|
||||
request = append(request, ip4...)
|
||||
} else {
|
||||
request = append(request, 0x04)
|
||||
request = append(request, ip.To16()...)
|
||||
}
|
||||
} else {
|
||||
if len(host) > 255 {
|
||||
return errors.New("socks5 target host too long")
|
||||
}
|
||||
request = append(request, 0x03, byte(len(host)))
|
||||
request = append(request, []byte(host)...)
|
||||
}
|
||||
|
||||
request = append(request, byte(port>>8), byte(port))
|
||||
_, err = conn.Write(request)
|
||||
return err
|
||||
}
|
||||
|
||||
func readSOCKS5ConnectResponse(conn net.Conn) error {
|
||||
header := make([]byte, 4)
|
||||
if _, err := io.ReadFull(conn, header); err != nil {
|
||||
return err
|
||||
}
|
||||
if header[0] != 0x05 {
|
||||
return fmt.Errorf("invalid socks5 response version %d", header[0])
|
||||
}
|
||||
if header[1] != 0x00 {
|
||||
return fmt.Errorf("socks5 connect failed with code %d", header[1])
|
||||
}
|
||||
|
||||
switch header[3] {
|
||||
case 0x01:
|
||||
if _, err := io.ReadFull(conn, make([]byte, 4)); err != nil {
|
||||
return err
|
||||
}
|
||||
case 0x03:
|
||||
size := make([]byte, 1)
|
||||
if _, err := io.ReadFull(conn, size); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.ReadFull(conn, make([]byte, int(size[0]))); err != nil {
|
||||
return err
|
||||
}
|
||||
case 0x04:
|
||||
if _, err := io.ReadFull(conn, make([]byte, 16)); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported socks5 bind address type %d", header[3])
|
||||
}
|
||||
|
||||
_, err := io.ReadFull(conn, make([]byte, 2))
|
||||
return err
|
||||
}
|
||||
|
||||
func applyConnDeadline(conn net.Conn, ctx context.Context, timeout time.Duration) func() {
|
||||
if conn == nil {
|
||||
return func() {}
|
||||
}
|
||||
|
||||
var (
|
||||
deadline time.Time
|
||||
hasValue bool
|
||||
)
|
||||
if ctx != nil {
|
||||
if ctxDeadline, ok := ctx.Deadline(); ok {
|
||||
deadline = ctxDeadline
|
||||
hasValue = true
|
||||
}
|
||||
}
|
||||
if timeout > 0 {
|
||||
timeoutDeadline := time.Now().Add(timeout)
|
||||
if !hasValue || timeoutDeadline.Before(deadline) {
|
||||
deadline = timeoutDeadline
|
||||
hasValue = true
|
||||
}
|
||||
}
|
||||
if !hasValue {
|
||||
return func() {}
|
||||
}
|
||||
|
||||
_ = conn.SetDeadline(deadline)
|
||||
return func() {
|
||||
_ = conn.SetDeadline(time.Time{})
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeProxyConfig(proxy *ProxyConfig, defaultTimeout time.Duration) *ProxyConfig {
|
||||
if proxy == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
normalized := *proxy
|
||||
normalized.Type = ProxyType(strings.ToLower(strings.TrimSpace(string(normalized.Type))))
|
||||
normalized.Addr = strings.TrimSpace(normalized.Addr)
|
||||
if normalized.Timeout <= 0 {
|
||||
normalized.Timeout = defaultTimeout
|
||||
}
|
||||
return &normalized
|
||||
}
|
||||
|
||||
func joinHostPort(host string, port int) string {
|
||||
return net.JoinHostPort(strings.TrimSpace(host), strconv.Itoa(port))
|
||||
}
|
||||
@@ -0,0 +1,237 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/sftp"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultSSHPort = 22
|
||||
defaultLoginTimeout = 5 * time.Second
|
||||
defaultSSHAgentTimeout = 2 * time.Minute
|
||||
defaultKeepAliveTimeout = 3 * time.Second
|
||||
defaultShellPollInterval = 120 * time.Millisecond
|
||||
defaultShellSetupDelay = 200 * time.Millisecond
|
||||
defaultShellSetupTimeout = 3 * time.Second
|
||||
defaultShellWaitTimeout = 30 * time.Second
|
||||
defaultShellPromptToken = "__STARSSH_PROMPT__>"
|
||||
defaultPTYTerm = "xterm"
|
||||
defaultPTYRows = 500
|
||||
defaultPTYColumns = 250
|
||||
|
||||
defaultTransferBufferSize = 1024 * 1024
|
||||
|
||||
defaultExecStreamMaxPendingChunks = 256
|
||||
defaultExecStreamMaxPendingBytes = 4 * 1024 * 1024
|
||||
)
|
||||
|
||||
type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error)
|
||||
|
||||
type ProxyType string
|
||||
|
||||
const (
|
||||
ProxyTypeSOCKS5 ProxyType = "socks5"
|
||||
ProxyTypeHTTPConnect ProxyType = "http_connect"
|
||||
)
|
||||
|
||||
type ProxyConfig struct {
|
||||
Type ProxyType
|
||||
Addr string
|
||||
Username string
|
||||
Password string
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
type AuthMethodKind string
|
||||
|
||||
const (
|
||||
AuthMethodPrivateKey AuthMethodKind = "private_key"
|
||||
AuthMethodPassword AuthMethodKind = "password"
|
||||
AuthMethodKeyboardInteractive AuthMethodKind = "keyboard_interactive"
|
||||
AuthMethodSSHAgent AuthMethodKind = "ssh_agent"
|
||||
)
|
||||
|
||||
type SSHAgentDebugFunc func(SSHAgentDebugEvent)
|
||||
|
||||
type SSHAgentDebugEvent struct {
|
||||
Step string
|
||||
Source string
|
||||
Endpoint string
|
||||
Network string
|
||||
Phase string
|
||||
Status string
|
||||
Duration time.Duration
|
||||
KeyCount int
|
||||
Err error
|
||||
}
|
||||
|
||||
type StarSSH struct {
|
||||
stateMu sync.RWMutex
|
||||
Client *ssh.Client
|
||||
PublicKey ssh.PublicKey
|
||||
PubkeyBase64 string
|
||||
Hostname string
|
||||
RemoteAddr net.Addr
|
||||
Banner string
|
||||
LoginInfo LoginInput
|
||||
online bool
|
||||
upstream *StarSSH
|
||||
sftpClient *sftp.Client
|
||||
sftpMu sync.Mutex
|
||||
agentForwardMu sync.Mutex
|
||||
agentForwarder io.Closer
|
||||
keepaliveMu sync.Mutex
|
||||
keepaliveStop chan struct{}
|
||||
keepaliveDone chan struct{}
|
||||
closing atomic.Bool
|
||||
}
|
||||
|
||||
type LoginInput struct {
|
||||
KeyExchanges []string
|
||||
Ciphers []string
|
||||
MACs []string
|
||||
User string
|
||||
Password string
|
||||
PasswordCallback func() (string, error)
|
||||
KeyboardInteractiveCallback ssh.KeyboardInteractiveChallenge
|
||||
Prikey string
|
||||
PrikeyPwd string
|
||||
DisableSSHAgent bool
|
||||
ForwardSSHAgent bool
|
||||
AuthOrder []AuthMethodKind
|
||||
// IdentityAgent overrides the local ssh-agent endpoint used for authentication
|
||||
// and agent forwarding. Empty uses SSH_AUTH_SOCK, or the platform default where
|
||||
// one exists.
|
||||
IdentityAgent string
|
||||
Addr string
|
||||
Port int
|
||||
// Timeout limits the SSH handshake/authentication phase after a TCP connection has
|
||||
// already been established. Zero means no authentication timeout.
|
||||
Timeout time.Duration
|
||||
// DialTimeout limits outbound dial steps such as TCP connect, proxy connect, and
|
||||
// local ssh-agent socket connect. Zero falls back to Timeout when set, otherwise
|
||||
// uses the package default dial timeout. Negative disables the default dial timeout.
|
||||
DialTimeout time.Duration
|
||||
// SSHAgentTimeout limits ssh-agent protocol operations such as listing keys and
|
||||
// signing challenges. Zero uses the package default, and negative disables the
|
||||
// per-operation deadline. This is intentionally separate from Timeout and
|
||||
// DialTimeout because hardware-backed agents may require a PIN or touch confirmation.
|
||||
SSHAgentTimeout time.Duration
|
||||
// SSHAgentForwardTimeout limits idle reads and writes on forwarded agent
|
||||
// channels. Zero or negative leaves forwarded channels without an idle deadline.
|
||||
SSHAgentForwardTimeout time.Duration
|
||||
// SSHAgentDebug receives structured ssh-agent dial/protocol events. It is nil by
|
||||
// default and must not log private key material.
|
||||
SSHAgentDebug SSHAgentDebugFunc
|
||||
DialContext DialContextFunc
|
||||
Proxy *ProxyConfig
|
||||
Jump *LoginInput
|
||||
KeepAliveInterval time.Duration
|
||||
KeepAliveTimeout time.Duration
|
||||
HostKeyCallback func(string, net.Addr, ssh.PublicKey) error
|
||||
BannerCallback func(string) error
|
||||
}
|
||||
|
||||
// StarShell keeps the legacy prompt-driven helper for POSIX-style scripted shell interactions.
|
||||
// It is not a generic cross-shell abstraction; for product-grade interactive terminals, prefer TerminalSession.
|
||||
type StarShell struct {
|
||||
Keyword string
|
||||
UseWaitDefault bool
|
||||
WaitTimeout time.Duration
|
||||
Session *ssh.Session
|
||||
in io.Writer
|
||||
out *bufio.Reader
|
||||
er *bufio.Reader
|
||||
outbyte []byte
|
||||
errbyte []byte
|
||||
lastout int64
|
||||
errors error
|
||||
isprint bool
|
||||
isfuncs bool
|
||||
iscolor bool
|
||||
isecho bool
|
||||
rw sync.RWMutex
|
||||
funcs func(string)
|
||||
writeMu sync.Mutex
|
||||
commandMu sync.Mutex
|
||||
closeOnce sync.Once
|
||||
promptToken string
|
||||
}
|
||||
|
||||
type TerminalConfig struct {
|
||||
Term string
|
||||
Rows int
|
||||
Columns int
|
||||
Modes ssh.TerminalModes
|
||||
}
|
||||
|
||||
type TerminalControl byte
|
||||
|
||||
const (
|
||||
TerminalControlInterrupt TerminalControl = 0x03
|
||||
TerminalControlEOF TerminalControl = 0x04
|
||||
TerminalControlBell TerminalControl = 0x07
|
||||
TerminalControlBackspace TerminalControl = 0x08
|
||||
TerminalControlLineKill TerminalControl = 0x15
|
||||
TerminalControlQuit TerminalControl = 0x1c
|
||||
TerminalControlSuspend TerminalControl = 0x1a
|
||||
TerminalControlPauseOutput TerminalControl = 0x13
|
||||
TerminalControlResumeOutput TerminalControl = 0x11
|
||||
)
|
||||
|
||||
type TerminalCloseReason string
|
||||
|
||||
const (
|
||||
TerminalCloseReasonUnknown TerminalCloseReason = ""
|
||||
TerminalCloseReasonExit TerminalCloseReason = "exit"
|
||||
TerminalCloseReasonSignal TerminalCloseReason = "signal"
|
||||
TerminalCloseReasonClosed TerminalCloseReason = "closed"
|
||||
TerminalCloseReasonContextCanceled TerminalCloseReason = "context_canceled"
|
||||
TerminalCloseReasonDeadlineExceeded TerminalCloseReason = "deadline_exceeded"
|
||||
TerminalCloseReasonTransportError TerminalCloseReason = "transport_error"
|
||||
)
|
||||
|
||||
type TerminalExitInfo struct {
|
||||
ExitCode int
|
||||
ExitSignal string
|
||||
ExitMessage string
|
||||
Reason TerminalCloseReason
|
||||
}
|
||||
|
||||
type TerminalSession struct {
|
||||
Session *ssh.Session
|
||||
ID string
|
||||
Label string
|
||||
Metadata map[string]string
|
||||
stdin io.WriteCloser
|
||||
stdout io.Reader
|
||||
stderr io.Reader
|
||||
|
||||
attachMu sync.RWMutex
|
||||
in io.Reader
|
||||
out io.Writer
|
||||
errOut io.Writer
|
||||
|
||||
runOnce sync.Once
|
||||
runDone chan struct{}
|
||||
runErr error
|
||||
|
||||
waitOnce sync.Once
|
||||
waitDone chan struct{}
|
||||
|
||||
stateMu sync.RWMutex
|
||||
waitErr error
|
||||
exitInfo TerminalExitInfo
|
||||
closeReason TerminalCloseReason
|
||||
closeErr error
|
||||
|
||||
closeOnce sync.Once
|
||||
}
|
||||
@@ -0,0 +1,162 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ansiCSIRegexp = regexp.MustCompile(`\x1b\[[0-9;?]*[ -/]*[@-~]`)
|
||||
ansiOSCRegexp = regexp.MustCompile(`\x1b\][^\x07]*(\x07|\x1b\\)`)
|
||||
leadingIntRegexp = regexp.MustCompile(`^[+-]?\d+`)
|
||||
)
|
||||
|
||||
func SedColor(str string) string {
|
||||
return stripControlSequences(str)
|
||||
}
|
||||
|
||||
func normalizeShellOutput(raw string) string {
|
||||
return strings.TrimSpace(strings.ReplaceAll(raw, "\r\n", "\n"))
|
||||
}
|
||||
|
||||
func stripControlSequences(raw string) string {
|
||||
cleaned := ansiOSCRegexp.ReplaceAllString(raw, "")
|
||||
cleaned = ansiCSIRegexp.ReplaceAllString(cleaned, "")
|
||||
cleaned = strings.ReplaceAll(cleaned, "\r", "")
|
||||
cleaned = strings.Map(func(r rune) rune {
|
||||
if r == '\n' || r == '\t' {
|
||||
return r
|
||||
}
|
||||
if r < 0x20 || r == 0x7f {
|
||||
return -1
|
||||
}
|
||||
return r
|
||||
}, cleaned)
|
||||
return cleaned
|
||||
}
|
||||
|
||||
func stripLeadingEcho(output string, command string, markerCommand string) string {
|
||||
result := output
|
||||
if command == "" {
|
||||
return result
|
||||
}
|
||||
|
||||
withMarker := command
|
||||
if markerCommand != "" {
|
||||
withMarker += "\n" + markerCommand
|
||||
}
|
||||
if strings.HasPrefix(result, withMarker) {
|
||||
return strings.TrimSpace(strings.TrimPrefix(result, withMarker))
|
||||
}
|
||||
if strings.HasPrefix(result, command) {
|
||||
return strings.TrimSpace(strings.TrimPrefix(result, command))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func collectLinesWithoutTokens(output string, tokens ...string) string {
|
||||
lines := strings.Split(output, "\n")
|
||||
filtered := make([]string, 0, len(lines))
|
||||
|
||||
for _, line := range lines {
|
||||
skip := false
|
||||
for _, token := range tokens {
|
||||
if token != "" && strings.Contains(line, token) {
|
||||
skip = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !skip {
|
||||
filtered = append(filtered, line)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.TrimSpace(strings.Join(filtered, "\n"))
|
||||
}
|
||||
|
||||
func collectLinesForCommandOutput(output string, promptToken string, tokens ...string) string {
|
||||
lines := strings.Split(output, "\n")
|
||||
filtered := make([]string, 0, len(lines))
|
||||
|
||||
for _, line := range lines {
|
||||
skip := false
|
||||
for _, token := range tokens {
|
||||
if token != "" && strings.Contains(line, token) {
|
||||
skip = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !skip && promptToken != "" && strings.Contains(line, promptToken) {
|
||||
skip = true
|
||||
}
|
||||
if !skip {
|
||||
filtered = append(filtered, line)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.TrimSpace(strings.Join(filtered, "\n"))
|
||||
}
|
||||
|
||||
func shellSingleQuote(s string) string {
|
||||
return "'" + strings.ReplaceAll(s, "'", "'\"'\"'") + "'"
|
||||
}
|
||||
|
||||
func normalizeBufferSize(bufcap int) int {
|
||||
if bufcap <= 0 {
|
||||
return defaultTransferBufferSize
|
||||
}
|
||||
return bufcap
|
||||
}
|
||||
|
||||
func newCommandTokens() (beginToken string, endToken string) {
|
||||
nonce := newNonce(8)
|
||||
return "__STARSSH_BEGIN_" + nonce + "__", "__STARSSH_END_" + nonce + "__"
|
||||
}
|
||||
|
||||
func newNonce(size int) string {
|
||||
if size <= 0 {
|
||||
size = 8
|
||||
}
|
||||
|
||||
buf := make([]byte, size)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return fmt.Sprintf("%d", time.Now().UnixNano())
|
||||
}
|
||||
return strings.ToUpper(hex.EncodeToString(buf))
|
||||
}
|
||||
|
||||
func splitByEndToken(output string, endToken string) (before string, exitCode int, found bool, parseErr error) {
|
||||
prefix := endToken + ":"
|
||||
lines := strings.Split(output, "\n")
|
||||
beforeLines := make([]string, 0, len(lines))
|
||||
|
||||
for _, line := range lines {
|
||||
trimmedLine := strings.TrimSpace(line)
|
||||
if !strings.HasPrefix(trimmedLine, prefix) {
|
||||
beforeLines = append(beforeLines, line)
|
||||
continue
|
||||
}
|
||||
|
||||
codeText := strings.TrimSpace(trimmedLine[len(prefix):])
|
||||
match := leadingIntRegexp.FindString(codeText)
|
||||
if match == "" || match != codeText {
|
||||
beforeLines = append(beforeLines, line)
|
||||
continue
|
||||
}
|
||||
|
||||
code, err := strconv.Atoi(match)
|
||||
if err != nil {
|
||||
beforeLines = append(beforeLines, line)
|
||||
continue
|
||||
}
|
||||
|
||||
return strings.Join(beforeLines, "\n"), code, true, nil
|
||||
}
|
||||
|
||||
return strings.Join(beforeLines, "\n"), 0, false, nil
|
||||
}
|
||||
Reference in New Issue
Block a user