starssh/exec.go

1219 lines
28 KiB
Go
Raw Permalink Normal View History

refactor: 重构 starssh 核心运行时并补强 ssh/exec/terminal/sftp 能力 - 拆分原有单体 ssh.go,按职责重组为 types、utils、transport、login、keepalive、session、exec、pool、shell、terminal、forward、hostkey、state 等模块,并补充平台相关实现 - 重做登录与连接运行时,补齐基于 context 的建连、jump/proxy 链路、可配置认证顺序,以及 Unix/Windows 下的 ssh-agent 支持 - 新增正式非交互执行模型 ExecRequest/ExecResult,支持流式输出、溢出统计、超时控制,以及 posix/powershell/cmd/raw 多方言执行 - 保留旧 shell 风格兼容接口,同时让路径/用户探测等 helper 具备跨 shell fallback,避免 Windows 目标继续硬依赖 POSIX 命令 - 新增 TerminalSession 作为原始交互终端基座,提供 IO attach、resize、signal/control、退出状态与关闭原因管理 - 重构端口转发语义,默认复用当前 SSH 连接,并显式提供 detached 的本地/动态转发模式承载隔离场景 - 梳理 keepalive 与取消语义,区分仅取消本次操作和关闭整条连接,并统一连接状态与传输关闭路径 - 围绕新的 session/连接生命周期重做执行池与运行时支撑 - 大幅增强 SFTP 传输链路,补齐更安全的原子替换、校验、进度回调、重试隔离、可复用 client 生命周期与失败语义 - 新增取消语义、keepalive、SFTP、forward、terminal input 等关键回归测试,提升核心链路稳定性
2026-04-26 10:45:39 +08:00
package starssh
import (
	"bytes"
	"context"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"
	"unicode"
	"unicode/utf16"

	"golang.org/x/crypto/ssh"
)
// ExecRequest describes one non-interactive remote command execution for
// StarSSH.Exec / StarSSH.ExecStream. Zero values select the defaults noted on
// each field.
type ExecRequest struct {
// Command is the command text to run; it must not be empty or whitespace-only.
Command string
// Stdin, when non-nil, is written to the remote process's standard input,
// after which stdin is closed.
Stdin []byte
// Env holds environment variables to set before Command runs. Keys are
// validated per dialect; the raw dialect rejects a non-empty Env.
Env map[string]string
// Dir is the working directory to switch to before Command runs; the raw
// dialect rejects it.
Dir string
// ShellDialect selects how Command, Env and Dir are rendered into the remote
// command line. Empty means ExecShellDialectPOSIX; matching ignores
// surrounding whitespace and letter case.
ShellDialect ExecShellDialect
// Timeout, when positive, bounds the whole execution via a derived context.
Timeout time.Duration
// PTY, when non-nil, makes the session be created with NewPTYSession.
PTY *TerminalConfig
// DiscardOutput disables capturing of stdout/stderr/combined output in the
// ExecResult; a streaming callback still receives every chunk.
DiscardOutput bool
// MaxOutputBytes caps each captured buffer (stdout, stderr, combined)
// independently; <= 0 means unlimited. Excess output is dropped and the
// matching *Truncated flag in ExecResult is set.
MaxOutputBytes int
// StreamMaxPendingChunks bounds how many chunks may wait for the streaming
// callback; <= 0 selects the package default.
StreamMaxPendingChunks int
// StreamMaxPendingBytes bounds how many bytes may wait for the streaming
// callback; <= 0 selects the package default.
StreamMaxPendingBytes int
// StreamOverflowStrategy decides what happens when the streaming queue is
// full; empty means ExecStreamOverflowDropOldest.
StreamOverflowStrategy ExecStreamOverflowStrategy
}
// ExecResult is the outcome of a command run through Exec / ExecStream.
type ExecResult struct {
// Command is the caller's ExecRequest.Command, before dialect rendering.
Command string
// Stdout and Stderr hold the captured streams; Combined holds both in the
// order the chunks were drained. All are empty with DiscardOutput.
Stdout []byte
Stderr []byte
Combined []byte
// The *Truncated flags report that the matching buffer hit MaxOutputBytes.
StdoutTruncated bool
StderrTruncated bool
CombinedTruncated bool
// StreamDroppedChunks / StreamDroppedBytes count chunks the streaming
// callback never received because of the queue limits.
StreamDroppedChunks int
StreamDroppedBytes int
// ExitCode, ExitSignal and ExitMessage are copied from *ssh.ExitError when
// the remote command exits unsuccessfully; they stay zero on a clean exit.
ExitCode int
ExitSignal string
ExitMessage string
// Duration spans from just before the session starts until output
// collection has finished.
Duration time.Duration
}
// ExecStreamChunk is one piece of output delivered to a streaming callback.
type ExecStreamChunk struct {
// Data is a private copy of the bytes read; callbacks may retain it.
Data []byte
// Stderr is true when Data came from the remote stderr stream.
Stderr bool
}
// ExecShellDialect selects how an ExecRequest is rendered into the remote
// command line.
type ExecShellDialect string
const (
// ExecShellDialectPOSIX renders "cd" / "export" for a POSIX shell (default).
ExecShellDialectPOSIX ExecShellDialect = "posix"
// ExecShellDialectPowerShell renders a PowerShell script that sets
// $ErrorActionPreference = 'Stop', Set-Location and $env: assignments.
ExecShellDialectPowerShell ExecShellDialect = "powershell"
// ExecShellDialectCMD renders cmd.exe syntax ("set", "cd /d").
ExecShellDialectCMD ExecShellDialect = "cmd"
// ExecShellDialectRaw sends Command verbatim; Dir and Env are rejected.
ExecShellDialectRaw ExecShellDialect = "raw"
)
// ExecStreamOverflowStrategy decides what happens when the queue feeding a
// streaming callback would exceed its chunk or byte limit.
type ExecStreamOverflowStrategy string
const (
// ExecStreamOverflowDropOldest evicts the oldest queued chunks (default).
ExecStreamOverflowDropOldest ExecStreamOverflowStrategy = "drop_oldest"
// ExecStreamOverflowDropNewest discards the chunk that does not fit.
ExecStreamOverflowDropNewest ExecStreamOverflowStrategy = "drop_newest"
// ExecStreamOverflowFail stops delivery and makes the exec call return an
// *ExecStreamOverflowError.
ExecStreamOverflowFail ExecStreamOverflowStrategy = "fail"
)
// ExecExitError reports that a remote command exited unsuccessfully
// (non-zero status or terminated by a signal).
type ExecExitError struct {
// Status is the remote exit status.
Status int
// Signal is the terminating signal name, if any.
Signal string
// Message is the trimmed exit message sent by the server, if any.
Message string
// Stderr is the trimmed captured stderr, if any.
Stderr string
}
// ExecStreamOverflowError reports chunks that never reached a streaming
// callback because of the queue limits.
type ExecStreamOverflowError struct {
DroppedChunks int
DroppedBytes int
}
// ShellExitError is an alias of ExecExitError, presumably kept for callers of
// the older shell-style API — confirm before removing.
type ShellExitError = ExecExitError
// execRequestRunner is the indirection tryLegacyProbeRequests uses to execute
// its probe requests; by default it calls StarSSH.Exec. Being a package
// variable, it is presumably a seam that tests replace — confirm.
var execRequestRunner = func(s *StarSSH, ctx context.Context, req ExecRequest) (*ExecResult, error) {
return s.Exec(ctx, req)
}
// Error renders the exit details. The format is "remote command exited",
// optionally followed by " with status N" (only for a non-zero status),
// " from signal SIG", ": <message>" and ": <stderr>". A nil receiver yields "".
func (e *ExecExitError) Error() string {
	if e == nil {
		return ""
	}
	var msg strings.Builder
	msg.WriteString("remote command exited")
	if e.Status != 0 {
		msg.WriteString(" with status ")
		msg.WriteString(strconv.Itoa(e.Status))
	}
	if e.Signal != "" {
		msg.WriteString(" from signal ")
		msg.WriteString(e.Signal)
	}
	for _, detail := range [...]string{e.Message, e.Stderr} {
		if detail == "" {
			continue
		}
		msg.WriteString(": ")
		msg.WriteString(detail)
	}
	return msg.String()
}
// ExitStatus returns the remote exit status, or 0 for a nil receiver.
func (e *ExecExitError) ExitStatus() int {
	status := 0
	if e != nil {
		status = e.Status
	}
	return status
}
// Error describes how much output the streaming callback missed. A nil
// receiver yields "".
func (e *ExecStreamOverflowError) Error() string {
	if e == nil {
		return ""
	}
	return "exec stream callback queue overflow: dropped " +
		strconv.Itoa(e.DroppedChunks) + " chunks (" +
		strconv.Itoa(e.DroppedBytes) + " bytes)"
}
// Success reports whether the command exited with status 0 and was not
// terminated by a signal. A nil result is never successful.
func (r *ExecResult) Success() bool {
	if r == nil {
		return false
	}
	return r.ExitSignal == "" && r.ExitCode == 0
}

// StdoutString returns the captured stdout as a string ("" for a nil result).
func (r *ExecResult) StdoutString() string {
	if r != nil {
		return string(r.Stdout)
	}
	return ""
}

// StderrString returns the captured stderr as a string ("" for a nil result).
func (r *ExecResult) StderrString() string {
	if r != nil {
		return string(r.Stderr)
	}
	return ""
}

// CombinedString returns the captured combined output as a string ("" for a
// nil result).
func (r *ExecResult) CombinedString() string {
	if r != nil {
		return string(r.Combined)
	}
	return ""
}
// CommandError converts an unsuccessful exit into an *ExecExitError carrying
// the trimmed exit message and stderr. It returns nil for a nil or successful
// result.
func (r *ExecResult) CommandError() error {
	if r == nil {
		return nil
	}
	if r.Success() {
		return nil
	}
	exitErr := &ExecExitError{
		Status: r.ExitCode,
		Signal: r.ExitSignal,
	}
	exitErr.Message = strings.TrimSpace(r.ExitMessage)
	exitErr.Stderr = strings.TrimSpace(r.StderrString())
	return exitErr
}

// OutputTruncated reports whether any captured buffer hit MaxOutputBytes.
func (r *ExecResult) OutputTruncated() bool {
	return r != nil && (r.StdoutTruncated || r.StderrTruncated || r.CombinedTruncated)
}

// StreamOutputDropped reports whether the streaming callback missed output.
func (r *ExecResult) StreamOutputDropped() bool {
	if r == nil {
		return false
	}
	return r.StreamDroppedChunks > 0 || r.StreamDroppedBytes > 0
}

// StreamOverflowError returns an *ExecStreamOverflowError describing dropped
// stream output, or nil when nothing was dropped (or r is nil).
func (r *ExecResult) StreamOverflowError() error {
	if !r.StreamOutputDropped() {
		return nil
	}
	overflow := &ExecStreamOverflowError{}
	overflow.DroppedChunks = r.StreamDroppedChunks
	overflow.DroppedBytes = r.StreamDroppedBytes
	return overflow
}
// Exec runs req on the remote host and returns its captured output and exit
// status. A non-zero exit is recorded in the result rather than returned as
// an error; use ExecResult.CommandError to obtain one.
func (s *StarSSH) Exec(ctx context.Context, req ExecRequest) (*ExecResult, error) {
	var noCallback func(ExecStreamChunk)
	return s.exec(ctx, req, noCallback)
}

// ExecString is shorthand for Exec with a request that only sets Command.
func (s *StarSSH) ExecString(ctx context.Context, command string) (*ExecResult, error) {
	req := ExecRequest{Command: command}
	return s.Exec(ctx, req)
}

// ExecStream behaves like Exec and additionally hands every stdout/stderr
// chunk to onChunk while the command runs. onChunk is called sequentially from
// a separate goroutine; chunks can be dropped per req.StreamOverflowStrategy
// when the callback falls behind.
func (s *StarSSH) ExecStream(ctx context.Context, req ExecRequest, onChunk func(ExecStreamChunk)) (*ExecResult, error) {
	result, err := s.exec(ctx, req, onChunk)
	return result, err
}
// exec runs req on a fresh session and gathers its output.
//
// When onChunk is non-nil every stdout/stderr chunk is also delivered to it in
// drain order through a bounded queue served by a separate goroutine (see
// execChunkDispatcher), limited by req.StreamMaxPending* and
// req.StreamOverflowStrategy.
//
// Once the remote command has started, the returned *ExecResult is non-nil
// even when an error is returned alongside it. Error precedence:
//   - ctx cancellation or deadline (including req.Timeout): ctx.Err();
//   - the first stdout/stderr read error;
//   - the overflow error of the ExecStreamOverflowFail strategy;
//   - any session.Wait error other than *ssh.ExitError.
//
// A non-zero exit status or exit signal is not an error here; it is recorded
// in ExitCode/ExitSignal/ExitMessage (see ExecResult.CommandError).
func (s *StarSSH) exec(ctx context.Context, req ExecRequest, onChunk func(ExecStreamChunk)) (*ExecResult, error) {
	if s == nil {
		return nil, errors.New("ssh client is nil")
	}
	if ctx == nil {
		ctx = context.Background()
	}
	if req.Timeout > 0 {
		timeoutCtx, cancel := context.WithTimeout(ctx, req.Timeout)
		defer cancel()
		ctx = timeoutCtx
	}
	remoteCommand, err := buildExecCommand(req)
	if err != nil {
		return nil, err
	}
	// Validate the streaming options before touching the remote host. They used
	// to be checked only after session.Start, so an invalid
	// StreamOverflowStrategy still ran the command and left the reader
	// goroutines blocked forever on the undrained chunk channel.
	dispatcher, err := newExecChunkDispatcher(req, onChunk)
	if err != nil {
		return nil, err
	}
	// Close is nil-safe and idempotent: this stops the dispatcher goroutine on
	// every early return below; on the normal path the drain loop closes it.
	defer dispatcher.Close()
	// Do not start a remote command for a context that is already done.
	if err := ctx.Err(); err != nil {
		return nil, err
	}
	session, err := s.newExecRuntimeSession(req)
	if err != nil {
		return nil, err
	}
	defer session.Close()
	var stdin io.WriteCloser
	if req.Stdin != nil {
		stdin, err = session.StdinPipe()
		if err != nil {
			return nil, err
		}
	}
	stdout, err := session.StdoutPipe()
	if err != nil {
		return nil, err
	}
	stderr, err := session.StderrPipe()
	if err != nil {
		return nil, err
	}
	result := &ExecResult{
		Command: req.Command,
	}
	startAt := time.Now()
	if err := session.Start(remoteCommand); err != nil {
		return nil, err
	}
	if stdin != nil {
		// Feed stdin from its own goroutine so this path can go on draining
		// stdout/stderr; write/close errors are intentionally ignored here.
		go func() {
			_, _ = stdin.Write(req.Stdin)
			_ = stdin.Close()
		}()
	}
	chunks := make(chan ExecStreamChunk, 16)
	readErrs := make(chan error, 2) // each reader sends at most one error
	var readWG sync.WaitGroup
	readWG.Add(2)
	go streamExecReader(stdout, false, chunks, readErrs, &readWG)
	go streamExecReader(stderr, true, chunks, readErrs, &readWG)
	go func() {
		readWG.Wait()
		close(chunks)
		close(readErrs)
	}()
	stdoutBuf := newCaptureBuffer(req.MaxOutputBytes, req.DiscardOutput)
	stderrBuf := newCaptureBuffer(req.MaxOutputBytes, req.DiscardOutput)
	combinedBuf := newCaptureBuffer(req.MaxOutputBytes, req.DiscardOutput)
	drainDone := make(chan struct{})
	go func() {
		defer close(drainDone)
		defer dispatcher.Close()
		for chunk := range chunks {
			if len(chunk.Data) == 0 {
				continue
			}
			_, _ = combinedBuf.Write(chunk.Data)
			if chunk.Stderr {
				_, _ = stderrBuf.Write(chunk.Data)
			} else {
				_, _ = stdoutBuf.Write(chunk.Data)
			}
			dispatcher.Enqueue(chunk)
		}
	}()
	waitCh := make(chan error, 1) // buffered so the Wait goroutine never blocks
	go func() {
		waitCh <- session.Wait()
	}()
	var waitErr error
	select {
	case waitErr = <-waitCh:
	case <-ctx.Done():
		_ = session.Close()
		waitErr = ctx.Err()
	}
	// Wait until both readers are done so the capture buffers are no longer
	// written. NOTE(review): after cancellation this relies on session.Close
	// unblocking the pipe readers.
	<-drainDone
	dispatchStats := dispatcher.Wait()
	readErr := firstExecError(readErrs)
	result.Stdout = append([]byte(nil), stdoutBuf.Bytes()...)
	result.Stderr = append([]byte(nil), stderrBuf.Bytes()...)
	result.Combined = append([]byte(nil), combinedBuf.Bytes()...)
	result.StdoutTruncated = stdoutBuf.Truncated()
	result.StderrTruncated = stderrBuf.Truncated()
	result.CombinedTruncated = combinedBuf.Truncated()
	result.StreamDroppedChunks = dispatchStats.droppedChunks
	result.StreamDroppedBytes = dispatchStats.droppedBytes
	result.Duration = time.Since(startAt)
	if errors.Is(waitErr, context.Canceled) || errors.Is(waitErr, context.DeadlineExceeded) {
		return result, waitErr
	}
	// An unsuccessful remote exit is data, not an error, at this layer.
	var exitErr *ssh.ExitError
	if errors.As(waitErr, &exitErr) {
		result.ExitCode = exitErr.ExitStatus()
		result.ExitSignal = exitErr.Signal()
		result.ExitMessage = exitErr.Msg()
		waitErr = nil
	}
	if readErr != nil {
		return result, readErr
	}
	if dispatchStats.err != nil {
		return result, dispatchStats.err
	}
	return result, waitErr
}
// newExecRuntimeSession opens the session for req: a plain exec session by
// default, or a PTY session when req.PTY is set.
func (s *StarSSH) newExecRuntimeSession(req ExecRequest) (*ssh.Session, error) {
	if req.PTY == nil {
		return s.NewExecSession()
	}
	return s.NewPTYSession(req.PTY)
}
// buildExecCommand renders req into the command line sent to the remote host,
// according to its (normalized) ShellDialect. It fails for an empty or
// whitespace-only Command and for an unknown dialect.
func buildExecCommand(req ExecRequest) (string, error) {
	if strings.TrimSpace(req.Command) == "" {
		return "", errors.New("command is empty")
	}
	dialect, err := normalizeExecShellDialect(req.ShellDialect)
	if err != nil {
		return "", err
	}
	var render func(ExecRequest) (string, error)
	switch dialect {
	case ExecShellDialectPOSIX:
		render = buildExecCommandPOSIX
	case ExecShellDialectPowerShell:
		render = buildExecCommandPowerShell
	case ExecShellDialectCMD:
		render = buildExecCommandCMD
	case ExecShellDialectRaw:
		render = buildExecCommandRaw
	default:
		return "", fmt.Errorf("unsupported exec shell dialect %q", req.ShellDialect)
	}
	return render(req)
}
// normalizeExecShellDialect trims and lower-cases dialect. An empty value maps
// to ExecShellDialectPOSIX; anything that is not a known dialect is an error.
func normalizeExecShellDialect(dialect ExecShellDialect) (ExecShellDialect, error) {
	trimmed := strings.TrimSpace(string(dialect))
	if trimmed == "" {
		return ExecShellDialectPOSIX, nil
	}
	normalized := ExecShellDialect(strings.ToLower(trimmed))
	switch normalized {
	case ExecShellDialectPOSIX, ExecShellDialectPowerShell, ExecShellDialectCMD, ExecShellDialectRaw:
		return normalized, nil
	}
	return "", fmt.Errorf("invalid exec shell dialect %q", dialect)
}
// buildExecCommandPOSIX renders req for a POSIX-compatible remote shell.
//
// Dir and Env become a setup prefix ("cd '<dir>' && export K='v' ...")
// guarded with "|| exit", so a failed setup ends the script with the failing
// command's status. req.Command follows on its own line and is parsed exactly
// as if it had been sent alone. Operators inside it (";", "||", "&", a
// trailing "# comment") can no longer bind to the prefix. Under the previous
// "cd dir && cmd" form, "a || b" ran b in the wrong directory when cd failed.
// Without Dir and Env, req.Command is returned unchanged.
func buildExecCommandPOSIX(req ExecRequest) (string, error) {
	setup := make([]string, 0, 2)
	if strings.TrimSpace(req.Dir) != "" {
		setup = append(setup, "cd "+shellSingleQuote(req.Dir))
	}
	if len(req.Env) > 0 {
		keys := make([]string, 0, len(req.Env))
		for key := range req.Env {
			if !isValidShellEnvKey(key) {
				return "", fmt.Errorf("invalid env key %q", key)
			}
			keys = append(keys, key)
		}
		sort.Strings(keys) // deterministic command text
		assignments := make([]string, 0, len(keys))
		for _, key := range keys {
			assignments = append(assignments, key+"="+shellSingleQuote(req.Env[key]))
		}
		setup = append(setup, "export "+strings.Join(assignments, " "))
	}
	if len(setup) == 0 {
		return req.Command, nil
	}
	// `exit` without an operand uses the status of the failed setup command.
	return strings.Join(setup, " && ") + " || exit\n" + req.Command, nil
}
// buildExecCommandPowerShell renders req as a "; "-separated PowerShell
// script. It first sets $ErrorActionPreference = 'Stop', so errors in the
// setup statements terminate the script. It then runs Set-Location for Dir,
// assigns $env:KEY for each Env entry (keys sorted), and finally runs Command.
func buildExecCommandPowerShell(req ExecRequest) (string, error) {
	statements := make([]string, 0, 3+len(req.Env))
	statements = append(statements, "$ErrorActionPreference = 'Stop'")
	if dir := req.Dir; strings.TrimSpace(dir) != "" {
		statements = append(statements, "Set-Location -LiteralPath "+powerShellSingleQuote(dir))
	}
	names := make([]string, 0, len(req.Env))
	for name := range req.Env {
		if !isValidShellEnvKey(name) {
			return "", fmt.Errorf("invalid env key %q", name)
		}
		names = append(names, name)
	}
	sort.Strings(names)
	for _, name := range names {
		statements = append(statements, "$env:"+name+" = "+powerShellSingleQuote(req.Env[name]))
	}
	statements = append(statements, req.Command)
	return strings.Join(statements, "; "), nil
}
// buildExecCommandCMD renders req for cmd.exe.
//
// Env entries become unquoted `set KEY=value` statements with caret-escaped
// values, and Dir becomes `cd /d <dir>`. cmd.exe expands %VAR% when it parses
// the whole line, before any `set` on that line runs. So references to the
// new variables (and to %CD% when Dir is set) are rewritten to delayed
// expansion (!VAR!), and the command then runs inside a nested
// `cmd.exe /V:ON` (see wrapCMDCommand).
//
// Each `set` is followed directly by "&&" with no space. In an unquoted
// `set NAME=value && ...`, cmd.exe keeps the space as part of the value, so
// the previous " && " join gave every variable a trailing space. (An empty
// value therefore now unsets the variable instead of setting it to " ".)
func buildExecCommandCMD(req ExecRequest) (string, error) {
	keys := make([]string, 0, len(req.Env))
	for key := range req.Env {
		if !isValidShellEnvKey(key) {
			return "", fmt.Errorf("invalid env key %q", key)
		}
		keys = append(keys, key)
	}
	sort.Strings(keys)
	hasDir := strings.TrimSpace(req.Dir) != ""
	// cmd.exe variable names are case-insensitive, hence the upper-case keys.
	replacements := make(map[string]string, len(keys)+1)
	if hasDir {
		replacements["CD"] = "!CD!"
	}
	for _, key := range keys {
		replacements[strings.ToUpper(key)] = "!" + key + "!"
	}
	command, needsDelayedExpansion := rewriteCMDPercentVariables(req.Command, replacements)
	if hasDir && cmdContainsBangVariable(command, "CD") {
		needsDelayedExpansion = true
	}
	for _, key := range keys {
		if cmdContainsBangVariable(command, key) {
			needsDelayedExpansion = true
		}
	}
	var builder strings.Builder
	if len(keys) > 0 {
		builder.WriteString("setlocal DisableDelayedExpansion && ")
		for _, key := range keys {
			builder.WriteString("set " + key + "=" + cmdEscapeForSetValue(req.Env[key]) + "&& ")
		}
	}
	if hasDir {
		builder.WriteString("cd /d " + cmdEscapeForBareArgument(req.Dir, true) + " && ")
	}
	if needsDelayedExpansion {
		builder.WriteString(wrapCMDCommand(command))
	} else {
		builder.WriteString(command)
	}
	return builder.String(), nil
}
// buildExecCommandRaw returns req.Command untouched. Dir and Env cannot be
// expressed without a known shell, so they are rejected.
func buildExecCommandRaw(req ExecRequest) (string, error) {
	switch {
	case strings.TrimSpace(req.Dir) != "":
		return "", errors.New("raw exec shell dialect does not support Dir")
	case len(req.Env) > 0:
		return "", errors.New("raw exec shell dialect does not support Env")
	default:
		return req.Command, nil
	}
}
// combineCommandOutput joins the non-empty parts of stdout and stderr with a
// newline. It returns nil when both are empty.
func combineCommandOutput(stdout string, stderr string) []byte {
	pieces := make([]string, 0, 2)
	for _, part := range [...]string{stdout, stderr} {
		if part != "" {
			pieces = append(pieces, part)
		}
	}
	if len(pieces) == 0 {
		return nil
	}
	return []byte(strings.Join(pieces, "\n"))
}
// powerShellSingleQuoteEscaper doubles every character PowerShell accepts as a
// single-quote delimiter: the ASCII apostrophe and the typographic quotes
// U+2018, U+2019, U+201A and U+201B. A Replacer is used instead of ranging
// over runes so that invalid UTF-8 bytes pass through unchanged.
var powerShellSingleQuoteEscaper = strings.NewReplacer(
	"'", "''",
	"\u2018", "\u2018\u2018",
	"\u2019", "\u2019\u2019",
	"\u201A", "\u201A\u201A",
	"\u201B", "\u201B\u201B",
)

// powerShellSingleQuote returns s as a PowerShell single-quoted (verbatim)
// string literal.
//
// Escaping only the ASCII apostrophe was not enough. PowerShell also ends a
// single-quoted string at the typographic quotes listed above, so a value
// such as "it’s" (U+2019) could terminate the literal early and inject script
// text. All of them are now doubled.
func powerShellSingleQuote(s string) string {
	return "'" + powerShellSingleQuoteEscaper.Replace(s) + "'"
}
// wrapCMDCommand runs script in a nested cmd.exe with delayed expansion
// enabled (/V:ON), echo off (/Q) and AutoRun disabled (/D). The script is
// escaped so the outer cmd.exe hands it through literally.
func wrapCMDCommand(script string) string {
	const launcher = `cmd.exe /Q /D /V:ON /C `
	escaped := cmdEscapeForNestedCommand(script)
	return launcher + escaped
}
// rewriteCMDPercentVariables replaces %NAME% references in a cmd.exe command
// line with replacements[strings.ToUpper(NAME)] (the callers pass "!NAME!").
// The aim is for the references to see values assigned earlier on the same
// line instead of the values substituted when cmd.exe parses the line. Name
// lookup is case-insensitive because the replacement keys are upper-case. It
// returns the rewritten command and whether anything was replaced.
//
// NOTE(review): a %...% pair whose name is not a replacement key is copied
// through with both percent signs consumed, as if it were a defined variable.
// For a literal percent sign ("50% of %KEY%") this pairs the wrong signs, and
// %KEY% is left unrewritten. As far as I know, cmd.exe instead rescans from
// just after the opening '%' of an undefined variable — confirm before
// changing, since which pairing is right depends on the remote environment.
func rewriteCMDPercentVariables(command string, replacements map[string]string) (string, bool) {
if len(replacements) == 0 || command == "" {
return command, false
}
var builder strings.Builder
builder.Grow(len(command))
rewrote := false
for i := 0; i < len(command); {
if command[i] != '%' {
builder.WriteByte(command[i])
i++
continue
}
// Look for the closing '%' of a %NAME% reference.
end := strings.IndexByte(command[i+1:], '%')
if end < 0 {
// Unpaired '%': keep it literally.
builder.WriteByte(command[i])
i++
continue
}
end += i + 1
name := command[i+1 : end]
if name == "" {
// "%%": keep both characters.
builder.WriteString(command[i : end+1])
i = end + 1
continue
}
replacement, ok := replacements[strings.ToUpper(name)]
if !ok {
// Not one of ours: copy the whole %NAME% span unchanged.
builder.WriteString(command[i : end+1])
i = end + 1
continue
}
builder.WriteString(replacement)
rewrote = true
i = end + 1
}
return builder.String(), rewrote
}
// cmdContainsBangVariable reports whether command references !name!
// (a delayed-expansion variable), comparing case-insensitively. Empty inputs
// never match.
func cmdContainsBangVariable(command string, name string) bool {
	if name == "" || command == "" {
		return false
	}
	needle := "!" + strings.ToUpper(name) + "!"
	haystack := strings.ToUpper(command)
	return strings.Contains(haystack, needle)
}
// cmdEscapeForSetValue escapes value for the right-hand side of an unquoted
// cmd.exe `set NAME=value`. The rules are exactly those of
// cmdEscapeForBareArgument without space escaping; the former duplicate loop
// is replaced by delegation so the two cannot drift apart.
func cmdEscapeForSetValue(value string) string {
	return cmdEscapeForBareArgument(value, false)
}

// cmdEscapeForBareArgument caret-escapes cmd.exe metacharacters in value:
// '^' becomes "^^"; & | < > ( ) and '"' get a '^' prefix; '%' becomes "%%";
// and, when escapeSpace is true, ' ' becomes "^ ". Other characters are copied
// unchanged.
//
// NOTE(review): "%%" collapses to a single '%' only in batch-file context; on
// a plain command line (cmd /c) it may stay "%%" — confirm against the remote.
func cmdEscapeForBareArgument(value string, escapeSpace bool) string {
	var builder strings.Builder
	builder.Grow(len(value))
	for _, char := range value {
		switch char {
		case '^':
			builder.WriteString("^^")
		case '&', '|', '<', '>', '(', ')', '"':
			builder.WriteByte('^')
			builder.WriteRune(char)
		case '%':
			builder.WriteString("%%")
		case ' ':
			if escapeSpace {
				builder.WriteString("^ ")
			} else {
				builder.WriteRune(char)
			}
		default:
			builder.WriteRune(char)
		}
	}
	return builder.String()
}

// cmdEscapeForNestedCommand escapes a complete command line so the outer
// cmd.exe passes it through literally to a nested cmd.exe; spaces are kept
// as-is.
func cmdEscapeForNestedCommand(command string) string {
	return cmdEscapeForBareArgument(command, false)
}
// captureBuffer is an io.Writer that keeps at most limit bytes (unlimited when
// limit <= 0) and remembers whether anything was cut off. With discard set it
// accepts and drops everything. Write always reports the full input length,
// so callers never see short writes.
type captureBuffer struct {
	limit     int
	discard   bool
	buffer    bytes.Buffer
	truncated bool
}

// newCaptureBuffer returns a captureBuffer with the given limit and discard mode.
func newCaptureBuffer(limit int, discard bool) *captureBuffer {
	b := new(captureBuffer)
	b.limit = limit
	b.discard = discard
	return b
}

// Write stores as much of data as the limit allows and returns len(data), nil.
// A nil receiver behaves like a discarding buffer.
func (b *captureBuffer) Write(data []byte) (int, error) {
	n := len(data)
	if n == 0 {
		return 0, nil
	}
	if b == nil || b.discard {
		return n, nil
	}
	if b.limit > 0 {
		room := b.limit - b.buffer.Len()
		if n > room {
			b.truncated = true
			if room <= 0 {
				return n, nil
			}
			data = data[:room]
		}
	}
	_, _ = b.buffer.Write(data)
	return n, nil
}

// Bytes returns the captured bytes (nil for a nil receiver). The slice aliases
// the internal buffer.
func (b *captureBuffer) Bytes() []byte {
	if b != nil {
		return b.buffer.Bytes()
	}
	return nil
}

// Truncated reports whether any input was dropped because of the limit.
func (b *captureBuffer) Truncated() bool {
	return b != nil && b.truncated
}
// isValidShellEnvKey reports whether key looks like an environment variable
// name: a letter or '_' first, then letters, digits or '_'.
//
// NOTE(review): unicode.IsLetter/IsDigit accept non-ASCII characters, while
// POSIX shells only allow [A-Za-z0-9_] names; such keys pass here but fail at
// `export` on the remote. Kept as-is because the PowerShell/CMD dialects share
// this check.
func isValidShellEnvKey(key string) bool {
	if key == "" {
		return false
	}
	for pos, r := range key {
		allowed := r == '_' || unicode.IsLetter(r) || (pos > 0 && unicode.IsDigit(r))
		if !allowed {
			return false
		}
	}
	return true
}
// streamExecReader copies reader into chunks until EOF. Each chunk holds a
// private copy of at most 4096 bytes tagged with isStderr. A read error other
// than io.EOF is sent once to errCh and ends the loop. wg.Done runs on return.
func streamExecReader(reader io.Reader, isStderr bool, chunks chan<- ExecStreamChunk, errCh chan<- error, wg *sync.WaitGroup) {
	defer wg.Done()
	scratch := make([]byte, 4096)
	for {
		n, err := reader.Read(scratch)
		if n > 0 {
			// scratch is reused by the next Read, so hand out a copy.
			chunks <- ExecStreamChunk{Data: append([]byte(nil), scratch[:n]...), Stderr: isStderr}
		}
		switch {
		case err == io.EOF:
			return
		case err != nil:
			errCh <- err
			return
		}
	}
}
// execChunkDispatcher delivers stream chunks to onChunk from its own goroutine
// (run) through a FIFO queue bounded by maxChunks and maxBytes. Enqueue never
// waits for the callback. When the queue would overflow, strategy decides
// whether the oldest queued chunks, the incoming chunk, or the whole rest of
// the stream (fail) is dropped. The mutable fields after mu are guarded by mu.
type execChunkDispatcher struct {
onChunk func(ExecStreamChunk)
// done is closed when run returns.
done chan struct{}
mu sync.Mutex
// cond wakes run when a chunk is queued or the dispatcher closes/stops.
cond *sync.Cond
queue []ExecStreamChunk
// queueBytes is the total len(Data) of the queued chunks.
queueBytes int
maxChunks int
maxBytes int
strategy ExecStreamOverflowStrategy
// closed: no more input; run exits after draining the queue.
closed bool
// stopped: delivery aborted by the fail strategy; the queue was discarded.
stopped bool
// failed: an overflow happened under ExecStreamOverflowFail.
failed bool
droppedBytes int
droppedCount int
}
// execChunkDispatchStats is the final accounting returned by Wait.
type execChunkDispatchStats struct {
droppedChunks int
droppedBytes int
// err is a *ExecStreamOverflowError when the fail strategy triggered.
err error
}
// newExecChunkDispatcher validates the streaming options in req and starts a
// dispatcher goroutine feeding onChunk. With a nil onChunk it returns
// (nil, nil); Enqueue, Close and Wait all accept a nil dispatcher.
func newExecChunkDispatcher(req ExecRequest, onChunk func(ExecStreamChunk)) (*execChunkDispatcher, error) {
	if onChunk == nil {
		return nil, nil
	}
	cfg, err := normalizeExecChunkDispatchConfig(req)
	if err != nil {
		return nil, err
	}
	d := new(execChunkDispatcher)
	d.onChunk = onChunk
	d.done = make(chan struct{})
	d.maxChunks, d.maxBytes, d.strategy = cfg.maxChunks, cfg.maxBytes, cfg.strategy
	d.cond = sync.NewCond(&d.mu)
	go d.run()
	return d, nil
}
// Enqueue queues chunk for delivery to onChunk without waiting for the
// callback. When the queue limits would be exceeded, the strategy applies:
//   - drop_oldest: evict queued chunks from the head until chunk fits; if it
//     still does not fit (e.g. it alone exceeds maxBytes), drop chunk itself;
//   - drop_newest: drop chunk;
//   - fail: drop chunk plus everything queued and stop delivery for good.
// Every dropped chunk is counted. A nil dispatcher or empty chunk is a no-op;
// chunks arriving after Close or a fail-stop are counted as dropped.
func (d *execChunkDispatcher) Enqueue(chunk ExecStreamChunk) {
if d == nil || len(chunk.Data) == 0 {
return
}
d.mu.Lock()
defer d.mu.Unlock()
if d.closed || d.stopped {
d.recordDropLocked(chunk)
return
}
switch d.strategy {
case ExecStreamOverflowDropOldest:
for len(d.queue) > 0 && d.wouldOverflowLocked(chunk) {
d.recordDropLocked(d.popOldestLocked())
}
// Queue is empty here if chunk still overflows: it is too big on its own.
if d.wouldOverflowLocked(chunk) {
d.recordDropLocked(chunk)
return
}
case ExecStreamOverflowDropNewest:
if d.wouldOverflowLocked(chunk) {
d.recordDropLocked(chunk)
return
}
case ExecStreamOverflowFail:
if d.wouldOverflowLocked(chunk) {
d.recordDropLocked(chunk)
d.stopWithOverflowLocked()
return
}
}
d.queue = append(d.queue, chunk)
d.queueBytes += len(chunk.Data)
// Only run waits on cond, so a single Signal suffices.
d.cond.Signal()
}
// Close marks the end of input; run delivers what is still queued and then
// exits. It is idempotent and a no-op for a nil dispatcher.
func (d *execChunkDispatcher) Close() {
	if d == nil {
		return
	}
	d.mu.Lock()
	defer d.mu.Unlock()
	if !d.closed {
		d.closed = true
		d.cond.Broadcast()
	}
}
// run calls onChunk for each queued chunk, one at a time and in FIFO order,
// until next reports that no more chunks will come; it then closes done.
func (d *execChunkDispatcher) run() {
	defer close(d.done)
	for chunk, ok := d.next(); ok; chunk, ok = d.next() {
		d.onChunk(chunk)
	}
}

// next blocks until a chunk is available or the dispatcher is closed/stopped.
// It returns false only when the queue is empty and no more input can arrive.
func (d *execChunkDispatcher) next() (ExecStreamChunk, bool) {
	d.mu.Lock()
	defer d.mu.Unlock()
	for {
		if len(d.queue) > 0 {
			return d.popOldestLocked(), true
		}
		if d.closed || d.stopped {
			return ExecStreamChunk{}, false
		}
		d.cond.Wait()
	}
}
// Wait blocks until the delivery goroutine has exited and returns the drop
// counters plus the fail-strategy error, if any. A nil dispatcher yields zero
// stats.
func (d *execChunkDispatcher) Wait() execChunkDispatchStats {
	var stats execChunkDispatchStats
	if d == nil {
		return stats
	}
	<-d.done
	d.mu.Lock()
	stats.droppedChunks = d.droppedCount
	stats.droppedBytes = d.droppedBytes
	stats.err = d.dispatchErrorLocked()
	d.mu.Unlock()
	return stats
}

// wouldOverflowLocked reports whether appending chunk would exceed the chunk
// or byte limit (a limit <= 0 disables that check). Caller holds d.mu.
func (d *execChunkDispatcher) wouldOverflowLocked(chunk ExecStreamChunk) bool {
	tooManyChunks := d.maxChunks > 0 && len(d.queue)+1 > d.maxChunks
	tooManyBytes := d.maxBytes > 0 && d.queueBytes+len(chunk.Data) > d.maxBytes
	return tooManyChunks || tooManyBytes
}
// popOldestLocked removes and returns the head of the queue (zero value when
// empty) and keeps queueBytes non-negative. Caller holds d.mu.
func (d *execChunkDispatcher) popOldestLocked() ExecStreamChunk {
	if len(d.queue) == 0 {
		return ExecStreamChunk{}
	}
	head := d.queue[0]
	d.queue[0] = ExecStreamChunk{} // drop the reference held by the backing array
	if len(d.queue) == 1 {
		d.queue = nil
	} else {
		d.queue = d.queue[1:]
	}
	d.queueBytes -= len(head.Data)
	if d.queueBytes < 0 {
		d.queueBytes = 0
	}
	return head
}

// recordDropLocked counts a non-empty chunk as dropped. Caller holds d.mu.
func (d *execChunkDispatcher) recordDropLocked(chunk ExecStreamChunk) {
	if n := len(chunk.Data); n > 0 {
		d.droppedCount++
		d.droppedBytes += n
	}
}

// stopWithOverflowLocked implements the fail strategy: every queued chunk is
// dropped, delivery stops, and the dispatcher is marked failed. It runs at
// most once. Caller holds d.mu.
func (d *execChunkDispatcher) stopWithOverflowLocked() {
	if d.stopped {
		return
	}
	for len(d.queue) > 0 {
		d.recordDropLocked(d.popOldestLocked())
	}
	d.stopped, d.failed = true, true
	d.cond.Broadcast()
}

// dispatchErrorLocked returns an *ExecStreamOverflowError carrying the drop
// counters once the fail strategy has triggered, else nil. Caller holds d.mu.
func (d *execChunkDispatcher) dispatchErrorLocked() error {
	if d.failed {
		return &ExecStreamOverflowError{DroppedChunks: d.droppedCount, DroppedBytes: d.droppedBytes}
	}
	return nil
}
// execChunkDispatchConfig holds the validated limits and overflow strategy of
// an execChunkDispatcher.
type execChunkDispatchConfig struct {
	maxChunks int
	maxBytes  int
	strategy  ExecStreamOverflowStrategy
}

// normalizeExecChunkDispatchConfig applies defaults to the streaming options
// in req and validates the overflow strategy.
//
// Non-positive limits fall back to defaultExecStreamMaxPendingChunks and
// defaultExecStreamMaxPendingBytes. The strategy now ignores surrounding
// whitespace and letter case, matching how normalizeExecShellDialect treats
// ShellDialect; previously "Drop_Oldest" was rejected while "POSIX" was
// accepted. An empty strategy means ExecStreamOverflowDropOldest; any other
// unknown value is an error that quotes the caller's original spelling.
func normalizeExecChunkDispatchConfig(req ExecRequest) (execChunkDispatchConfig, error) {
	strategy := ExecStreamOverflowStrategy(strings.ToLower(strings.TrimSpace(string(req.StreamOverflowStrategy))))
	if strategy == "" {
		strategy = ExecStreamOverflowDropOldest
	}
	switch strategy {
	case ExecStreamOverflowDropOldest, ExecStreamOverflowDropNewest, ExecStreamOverflowFail:
	default:
		return execChunkDispatchConfig{}, fmt.Errorf("invalid exec stream overflow strategy %q", req.StreamOverflowStrategy)
	}
	config := execChunkDispatchConfig{
		maxChunks: req.StreamMaxPendingChunks,
		maxBytes:  req.StreamMaxPendingBytes,
		strategy:  strategy,
	}
	if config.maxChunks <= 0 {
		config.maxChunks = defaultExecStreamMaxPendingChunks
	}
	if config.maxBytes <= 0 {
		config.maxBytes = defaultExecStreamMaxPendingBytes
	}
	return config, nil
}
// firstExecError returns the first non-nil error received from errCh, or nil.
// If no error arrives, it blocks until errCh is closed.
func firstExecError(errCh <-chan error) error {
	for {
		err, open := <-errCh
		if !open {
			return nil
		}
		if err != nil {
			return err
		}
	}
}
// ShellOne runs cmd with the default (POSIX) dialect and returns the trimmed
// combined stdout+stderr. An unsuccessful exit is returned as *ExecExitError
// together with that output; a transport or context error returns "".
func (s *StarSSH) ShellOne(cmd string) (string, error) {
	result, err := s.Exec(context.Background(), ExecRequest{Command: cmd})
	if err != nil {
		return "", err
	}
	output := strings.TrimSpace(result.CombinedString())
	return output, result.CommandError()
}
// ShellOneShowScreen runs cmd, printing every output chunk (stdout and stderr)
// to the local standard output as it arrives, and returns the trimmed stdout.
func (s *StarSSH) ShellOneShowScreen(cmd string) (string, error) {
	echo := func(chunk string) { fmt.Print(chunk) }
	return s.ShellOneToFunc(cmd, echo)
}

// ShellOneShowScreenResult is like ShellOneShowScreen but returns the full
// *ExecResult.
func (s *StarSSH) ShellOneShowScreenResult(cmd string) (*ExecResult, error) {
	echo := func(chunk string) { fmt.Print(chunk) }
	return s.ShellOneToFuncResult(cmd, echo)
}

// ShellOneToFunc runs cmd, passing each output chunk to callback, and returns
// the trimmed stdout.
func (s *StarSSH) ShellOneToFunc(cmd string, callback func(string)) (string, error) {
	output, err := s.streamCommand(cmd, callback)
	return output, err
}

// ShellOneToFuncResult runs cmd, passing each output chunk to callback, and
// returns the full *ExecResult.
func (s *StarSSH) ShellOneToFuncResult(cmd string, callback func(string)) (*ExecResult, error) {
	result, err := s.streamCommandResult(cmd, callback)
	return result, err
}
// streamCommand runs cmd with streaming and returns the trimmed stdout. Apart
// from exec errors, it reports an unsuccessful exit and any dropped stream
// output (joined) as the error.
func (s *StarSSH) streamCommand(cmd string, onChunk func(string)) (string, error) {
	result, err := s.streamCommandResult(cmd, onChunk)
	stdoutText := strings.TrimSpace(resultStdoutString(result))
	if err == nil {
		err = streamCommandLegacyError(result)
	}
	return stdoutText, err
}

// streamCommandResult runs cmd via ExecStream, forwarding each chunk's bytes
// as a string to onChunk (ignored when onChunk is nil).
func (s *StarSSH) streamCommandResult(cmd string, onChunk func(string)) (*ExecResult, error) {
	forward := func(chunk ExecStreamChunk) {
		if onChunk == nil {
			return
		}
		onChunk(string(chunk.Data))
	}
	return s.ExecStream(context.Background(), ExecRequest{Command: cmd}, forward)
}
// streamCommandLegacyError joins the exit error and the stream-overflow error
// of result. It returns nil when both are nil or result is nil.
func streamCommandLegacyError(result *ExecResult) error {
	if result != nil {
		return errors.Join(result.CommandError(), result.StreamOverflowError())
	}
	return nil
}

// resultStdoutString returns result's stdout as a string, or "" for nil.
func resultStdoutString(result *ExecResult) string {
	if result != nil {
		return result.StdoutString()
	}
	return ""
}
// Exists reports whether filepath exists on the remote host; probe failures
// read as false.
func (s *StarSSH) Exists(filepath string) bool {
	const kind = legacyPathProbeExists
	return s.remotePathProbe(kind, filepath)
}

// IsFile reports whether filepath is a regular file on the remote host.
func (s *StarSSH) IsFile(filepath string) bool {
	const kind = legacyPathProbeFile
	return s.remotePathProbe(kind, filepath)
}

// IsFolder reports whether filepath is a directory on the remote host.
func (s *StarSSH) IsFolder(filepath string) bool {
	const kind = legacyPathProbeDirectory
	return s.remotePathProbe(kind, filepath)
}

// GetUid returns the remote user id (a SID on Windows), or "" on failure.
func (s *StarSSH) GetUid() string {
	const kind = legacyIdentityProbeUID
	return s.remoteIdentityProbe(kind)
}

// GetGid returns the remote group id, or "" on failure.
func (s *StarSSH) GetGid() string {
	const kind = legacyIdentityProbeGID
	return s.remoteIdentityProbe(kind)
}

// GetUser returns the remote user name, or "" on failure.
func (s *StarSSH) GetUser() string {
	const kind = legacyIdentityProbeUser
	return s.remoteIdentityProbe(kind)
}

// GetGroup returns the remote group name, or "" on failure.
func (s *StarSSH) GetGroup() string {
	const kind = legacyIdentityProbeGroup
	return s.remoteIdentityProbe(kind)
}
// legacyPathProbeKind selects the path test run by the legacy path helpers.
type legacyPathProbeKind string
const (
// legacyPathProbeExists: any entry exists at the path.
legacyPathProbeExists legacyPathProbeKind = "exists"
// legacyPathProbeFile: the path is a regular file (not a directory).
legacyPathProbeFile legacyPathProbeKind = "file"
// legacyPathProbeDirectory: the path is a directory.
legacyPathProbeDirectory legacyPathProbeKind = "directory"
)
// legacyIdentityProbeKind selects which identity value the legacy helpers
// query.
type legacyIdentityProbeKind string
const (
legacyIdentityProbeUID legacyIdentityProbeKind = "uid"
legacyIdentityProbeGID legacyIdentityProbeKind = "gid"
legacyIdentityProbeUser legacyIdentityProbeKind = "user"
legacyIdentityProbeGroup legacyIdentityProbeKind = "group"
)
// remotePathProbe runs the cross-shell path probes for kind and reports
// whether the first successful probe printed "1".
func (s *StarSSH) remotePathProbe(kind legacyPathProbeKind, filepath string) bool {
	output, ok := s.tryLegacyProbeRequests(buildLegacyPathProbeRequests(kind, filepath))
	if !ok {
		return false
	}
	return strings.TrimSpace(output) == "1"
}

// remoteIdentityProbe runs the cross-shell identity probes for kind and
// returns the trimmed stdout of the first successful one, or "".
func (s *StarSSH) remoteIdentityProbe(kind legacyIdentityProbeKind) string {
	output, ok := s.tryLegacyProbeRequests(buildLegacyIdentityProbeRequests(kind))
	if ok {
		return strings.TrimSpace(output)
	}
	return ""
}

// tryLegacyProbeRequests runs requests in order through execRequestRunner and
// returns the stdout of the first one that completes without error and with a
// successful exit. ok is false if none succeeds or s is nil.
func (s *StarSSH) tryLegacyProbeRequests(requests []ExecRequest) (string, bool) {
	if s == nil {
		return "", false
	}
	ctx := context.Background()
	for i := range requests {
		result, err := execRequestRunner(s, ctx, requests[i])
		if err == nil && result != nil && result.CommandError() == nil {
			return result.StdoutString(), true
		}
	}
	return "", false
}
// buildLegacyPathProbeRequests lists the path probes in fallback order:
// POSIX sh, then powershell.exe / pwsh.exe / pwsh, then cmd.exe.
func buildLegacyPathProbeRequests(kind legacyPathProbeKind, filepath string) []ExecRequest {
	powerShells := []string{"powershell.exe", "pwsh.exe", "pwsh"}
	requests := make([]ExecRequest, 0, len(powerShells)+2)
	posixScript := buildLegacyPOSIXPathProbeScript(kind, filepath)
	requests = append(requests, buildLegacyRawExecRequest(wrapPOSIXRawCommand(posixScript)))
	psScript := buildLegacyPowerShellPathProbeScript(kind, filepath)
	for _, exe := range powerShells {
		requests = append(requests, buildLegacyRawExecRequest(wrapPowerShellRawCommand(exe, psScript)))
	}
	return append(requests, buildLegacyRawExecRequest(buildLegacyCMDPathProbeCommand(kind, filepath)))
}

// buildLegacyIdentityProbeRequests lists the identity probes in fallback
// order: POSIX sh, then powershell.exe / pwsh.exe / pwsh. Only the user-name
// probe also has a cmd.exe fallback ("whoami").
func buildLegacyIdentityProbeRequests(kind legacyIdentityProbeKind) []ExecRequest {
	posixScript := buildLegacyPOSIXIdentityProbeScript(kind)
	requests := []ExecRequest{buildLegacyRawExecRequest(wrapPOSIXRawCommand(posixScript))}
	psScript := buildLegacyPowerShellIdentityProbeScript(kind)
	for _, exe := range []string{"powershell.exe", "pwsh.exe", "pwsh"} {
		requests = append(requests, buildLegacyRawExecRequest(wrapPowerShellRawCommand(exe, psScript)))
	}
	if kind != legacyIdentityProbeUser {
		return requests
	}
	return append(requests, buildLegacyRawExecRequest("cmd.exe /Q /D /C whoami"))
}
// buildLegacyRawExecRequest wraps command in an ExecRequest using the raw
// dialect, so it is sent to the remote host verbatim.
func buildLegacyRawExecRequest(command string) ExecRequest {
	var req ExecRequest
	req.Command = command
	req.ShellDialect = ExecShellDialectRaw
	return req
}

// wrapPOSIXRawCommand runs script through `sh -lc`, single-quoted so the
// remote login shell passes it through unchanged.
func wrapPOSIXRawCommand(script string) string {
	const launcher = "sh -lc "
	return launcher + shellSingleQuote(script)
}
// wrapPowerShellRawCommand builds a command line that runs script with the
// given PowerShell executable (no logo, no profile, non-interactive).
//
// The script is passed via -EncodedCommand as base64 of its UTF-16LE bytes,
// which the calling shell leaves alone (only [A-Za-z0-9+/=] characters). The
// former `-Command '<script>'` form broke under cmd.exe, the default Windows
// OpenSSH shell. cmd.exe does not treat single quotes as quoting, so
// PowerShell received a single-quoted *string literal* and merely echoed the
// script text. Identity probes therefore returned the script itself, and path
// probes always read as false.
func wrapPowerShellRawCommand(executable string, script string) string {
	units := utf16.Encode([]rune(script))
	payload := make([]byte, 0, 2*len(units))
	for _, unit := range units {
		payload = append(payload, byte(unit), byte(unit>>8)) // little-endian
	}
	return executable + " -NoLogo -NoProfile -NonInteractive -EncodedCommand " + base64.StdEncoding.EncodeToString(payload)
}
// buildLegacyPOSIXPathProbeScript returns a POSIX sh script that prints "1"
// when filepath passes the test for kind (-e exists, -f regular file,
// -d directory) and "0" otherwise.
//
// The previous `[ -e -- path ]` was always false. test/[ does not recognize
// "--" (POSIX says so explicitly), so the three-operand expression was a
// syntax error, the else branch ran, and every probe printed "0". The
// two-operand form `[ -e path ]` always treats its second word as the operand,
// even for paths starting with "-" or spelled like operators, so no guard is
// needed.
func buildLegacyPOSIXPathProbeScript(kind legacyPathProbeKind, filepath string) string {
	flag := "-e"
	switch kind {
	case legacyPathProbeFile:
		flag = "-f"
	case legacyPathProbeDirectory:
		flag = "-d"
	}
	return fmt.Sprintf("if [ %s %s ]; then printf '1\\n'; else printf '0\\n'; fi", flag, shellSingleQuote(filepath))
}
// buildLegacyPowerShellPathProbeScript returns a PowerShell script that writes
// "1" when Test-Path succeeds for filepath (restricted to Leaf or Container
// for the file/directory kinds) and "0" otherwise.
func buildLegacyPowerShellPathProbeScript(kind legacyPathProbeKind, filepath string) string {
	pathType := ""
	switch kind {
	case legacyPathProbeFile:
		pathType = " -PathType Leaf"
	case legacyPathProbeDirectory:
		pathType = " -PathType Container"
	}
	test := "Test-Path -LiteralPath " + powerShellSingleQuote(filepath) + pathType
	return "if (" + test + ") { Write-Output '1' } else { Write-Output '0' }"
}
// buildLegacyCMDPathProbeCommand returns a cmd.exe command line that prints
// "1" when filepath passes the test for kind and "0" otherwise.
//
// filepath is stored in STARSSH_PATH and read back only through delayed
// expansion (!STARSSH_PATH!) inside a nested `cmd.exe /V:ON`, so the path's
// characters are never re-parsed as cmd syntax. A "\NUL" suffix is used to
// tell directories from files.
//
// The assignment is an unquoted `set NAME=value` with caret-escaped
// metacharacters, followed directly by "&&". The previous quoted form
// `set "NAME=..."` kept the carets literally, because cmd.exe does not process
// '^' inside double quotes. Paths such as `C:\Program Files (x86)` were thus
// stored as `C:\Program Files ^(x86^)` and never found. The "&&" must follow
// without a space, otherwise the space would become part of the unquoted value.
func buildLegacyCMDPathProbeCommand(kind legacyPathProbeKind, filepath string) string {
	var probe string
	switch kind {
	case legacyPathProbeExists:
		probe = `if exist "!STARSSH_PATH!" (echo 1) else echo 0`
	case legacyPathProbeFile:
		probe = `if exist "!STARSSH_PATH!" (if exist "!STARSSH_PATH!\NUL" (echo 0) else echo 1) else echo 0`
	case legacyPathProbeDirectory:
		probe = `if exist "!STARSSH_PATH!\NUL" (echo 1) else echo 0`
	}
	setup := "setlocal DisableDelayedExpansion && set STARSSH_PATH=" + cmdEscapeForSetValue(filepath)
	if probe == "" {
		return setup
	}
	return setup + "&& " + wrapCMDCommand(probe)
}
// buildLegacyPOSIXIdentityProbeScript returns the `id` invocation for kind,
// or "" for an unknown kind.
func buildLegacyPOSIXIdentityProbeScript(kind legacyIdentityProbeKind) string {
	scripts := map[legacyIdentityProbeKind]string{
		legacyIdentityProbeUID:   "id -u",
		legacyIdentityProbeGID:   "id -g",
		legacyIdentityProbeUser:  "id -un",
		legacyIdentityProbeGroup: "id -gn",
	}
	return scripts[kind]
}

// buildLegacyPowerShellIdentityProbeScript returns the PowerShell expression
// for kind, or "" for an unknown kind. UID is the current user's SID and
// USER is $env:USERNAME. GID/GROUP use the first entry of the token's Groups
// list (translated to an account name for GROUP when possible) — presumably
// an approximation of a "primary group"; confirm if exact semantics matter.
func buildLegacyPowerShellIdentityProbeScript(kind legacyIdentityProbeKind) string {
	scripts := map[legacyIdentityProbeKind]string{
		legacyIdentityProbeUID:   "[System.Security.Principal.WindowsIdentity]::GetCurrent().User.Value",
		legacyIdentityProbeGID:   "$id = [System.Security.Principal.WindowsIdentity]::GetCurrent(); $group = $id.Groups | Select-Object -First 1; if ($group) { $group.Value }",
		legacyIdentityProbeUser:  "$env:USERNAME",
		legacyIdentityProbeGroup: "$id = [System.Security.Principal.WindowsIdentity]::GetCurrent(); $group = $id.Groups | Select-Object -First 1; if ($group) { try { $group.Translate([System.Security.Principal.NTAccount]).Value } catch { $group.Value } }",
	}
	return scripts[kind]
}