package starssh import ( "bytes" "context" "errors" "fmt" "io" "sort" "strconv" "strings" "sync" "time" "unicode" "golang.org/x/crypto/ssh" ) type ExecRequest struct { Command string Stdin []byte Env map[string]string Dir string ShellDialect ExecShellDialect Timeout time.Duration PTY *TerminalConfig DiscardOutput bool MaxOutputBytes int StreamMaxPendingChunks int StreamMaxPendingBytes int StreamOverflowStrategy ExecStreamOverflowStrategy } type ExecResult struct { Command string Stdout []byte Stderr []byte Combined []byte StdoutTruncated bool StderrTruncated bool CombinedTruncated bool StreamDroppedChunks int StreamDroppedBytes int ExitCode int ExitSignal string ExitMessage string Duration time.Duration } type ExecStreamChunk struct { Data []byte Stderr bool } type ExecShellDialect string const ( ExecShellDialectPOSIX ExecShellDialect = "posix" ExecShellDialectPowerShell ExecShellDialect = "powershell" ExecShellDialectCMD ExecShellDialect = "cmd" ExecShellDialectRaw ExecShellDialect = "raw" ) type ExecStreamOverflowStrategy string const ( ExecStreamOverflowDropOldest ExecStreamOverflowStrategy = "drop_oldest" ExecStreamOverflowDropNewest ExecStreamOverflowStrategy = "drop_newest" ExecStreamOverflowFail ExecStreamOverflowStrategy = "fail" ) type ExecExitError struct { Status int Signal string Message string Stderr string } type ExecStreamOverflowError struct { DroppedChunks int DroppedBytes int } type ShellExitError = ExecExitError var execRequestRunner = func(s *StarSSH, ctx context.Context, req ExecRequest) (*ExecResult, error) { return s.Exec(ctx, req) } func (e *ExecExitError) Error() string { if e == nil { return "" } base := "remote command exited" if e.Status != 0 { base += " with status " + strconv.Itoa(e.Status) } if e.Signal != "" { base += " from signal " + e.Signal } if e.Message != "" { base += ": " + e.Message } if e.Stderr != "" { base += ": " + e.Stderr } return base } func (e *ExecExitError) ExitStatus() int { if e == nil { return 0 } return e.Status } func (e *ExecStreamOverflowError) Error() string { if e == nil { return "" } return fmt.Sprintf("exec stream callback queue overflow: dropped %d chunks (%d bytes)", e.DroppedChunks, e.DroppedBytes) } func (r *ExecResult) Success() bool { return r != nil && r.ExitCode == 0 && r.ExitSignal == "" } func (r *ExecResult) StdoutString() string { if r == nil { return "" } return string(r.Stdout) } func (r *ExecResult) StderrString() string { if r == nil { return "" } return string(r.Stderr) } func (r *ExecResult) CombinedString() string { if r == nil { return "" } return string(r.Combined) } func (r *ExecResult) CommandError() error { if r == nil || r.Success() { return nil } return &ExecExitError{ Status: r.ExitCode, Signal: r.ExitSignal, Message: strings.TrimSpace(r.ExitMessage), Stderr: strings.TrimSpace(r.StderrString()), } } func (r *ExecResult) OutputTruncated() bool { if r == nil { return false } return r.StdoutTruncated || r.StderrTruncated || r.CombinedTruncated } func (r *ExecResult) StreamOutputDropped() bool { return r != nil && (r.StreamDroppedChunks > 0 || r.StreamDroppedBytes > 0) } func (r *ExecResult) StreamOverflowError() error { if r == nil || !r.StreamOutputDropped() { return nil } return &ExecStreamOverflowError{ DroppedChunks: r.StreamDroppedChunks, DroppedBytes: r.StreamDroppedBytes, } } func (s *StarSSH) Exec(ctx context.Context, req ExecRequest) (*ExecResult, error) { return s.exec(ctx, req, nil) } func (s *StarSSH) ExecString(ctx context.Context, command string) (*ExecResult, error) { return s.Exec(ctx, ExecRequest{ Command: command, }) } func (s *StarSSH) ExecStream(ctx context.Context, req ExecRequest, onChunk func(ExecStreamChunk)) (*ExecResult, error) { return s.exec(ctx, req, onChunk) } func (s *StarSSH) exec(ctx context.Context, req ExecRequest, onChunk func(ExecStreamChunk)) (*ExecResult, error) { if s == nil { return nil, errors.New("ssh client is nil") } if ctx == nil { ctx = context.Background() } if req.Timeout > 0 { timeoutCtx, cancel := context.WithTimeout(ctx, req.Timeout) defer cancel() ctx = timeoutCtx } remoteCommand, err := buildExecCommand(req) if err != nil { return nil, err } session, err := s.newExecRuntimeSession(req) if err != nil { return nil, err } defer session.Close() var stdin io.WriteCloser if req.Stdin != nil { stdin, err = session.StdinPipe() if err != nil { return nil, err } } stdout, err := session.StdoutPipe() if err != nil { return nil, err } stderr, err := session.StderrPipe() if err != nil { return nil, err } result := &ExecResult{ Command: req.Command, } startAt := time.Now() if err := session.Start(remoteCommand); err != nil { return nil, err } if stdin != nil { go func() { _, _ = stdin.Write(req.Stdin) _ = stdin.Close() }() } chunks := make(chan ExecStreamChunk, 16) readErrs := make(chan error, 2) var readWG sync.WaitGroup readWG.Add(2) go streamExecReader(stdout, false, chunks, readErrs, &readWG) go streamExecReader(stderr, true, chunks, readErrs, &readWG) go func() { readWG.Wait() close(chunks) close(readErrs) }() stdoutBuf := newCaptureBuffer(req.MaxOutputBytes, req.DiscardOutput) stderrBuf := newCaptureBuffer(req.MaxOutputBytes, req.DiscardOutput) combinedBuf := newCaptureBuffer(req.MaxOutputBytes, req.DiscardOutput) dispatcher, err := newExecChunkDispatcher(req, onChunk) if err != nil { return nil, err } drainDone := make(chan struct{}) go func() { defer close(drainDone) defer dispatcher.Close() for chunk := range chunks { if len(chunk.Data) == 0 { continue } _, _ = combinedBuf.Write(chunk.Data) if chunk.Stderr { _, _ = stderrBuf.Write(chunk.Data) } else { _, _ = stdoutBuf.Write(chunk.Data) } dispatcher.Enqueue(chunk) } }() waitCh := make(chan error, 1) go func() { waitCh <- session.Wait() }() var waitErr error select { case waitErr = <-waitCh: case <-ctx.Done(): _ = session.Close() waitErr = ctx.Err() } <-drainDone dispatchStats := dispatcher.Wait() readErr := firstExecError(readErrs) result.Stdout = append(result.Stdout[:0], stdoutBuf.Bytes()...) result.Stderr = append(result.Stderr[:0], stderrBuf.Bytes()...) result.Combined = append(result.Combined[:0], combinedBuf.Bytes()...) result.StdoutTruncated = stdoutBuf.Truncated() result.StderrTruncated = stderrBuf.Truncated() result.CombinedTruncated = combinedBuf.Truncated() result.StreamDroppedChunks = dispatchStats.droppedChunks result.StreamDroppedBytes = dispatchStats.droppedBytes result.Duration = time.Since(startAt) if errors.Is(waitErr, context.Canceled) || errors.Is(waitErr, context.DeadlineExceeded) { return result, waitErr } var exitErr *ssh.ExitError if errors.As(waitErr, &exitErr) { result.ExitCode = exitErr.ExitStatus() result.ExitSignal = exitErr.Signal() result.ExitMessage = exitErr.Msg() waitErr = nil } if readErr != nil { return result, readErr } if dispatchStats.err != nil { return result, dispatchStats.err } if waitErr == nil { return result, nil } return result, waitErr } func (s *StarSSH) newExecRuntimeSession(req ExecRequest) (*ssh.Session, error) { if req.PTY != nil { return s.NewPTYSession(req.PTY) } return s.NewExecSession() } func buildExecCommand(req ExecRequest) (string, error) { if strings.TrimSpace(req.Command) == "" { return "", errors.New("command is empty") } dialect, err := normalizeExecShellDialect(req.ShellDialect) if err != nil { return "", err } switch dialect { case ExecShellDialectPOSIX: return buildExecCommandPOSIX(req) case ExecShellDialectPowerShell: return buildExecCommandPowerShell(req) case ExecShellDialectCMD: return buildExecCommandCMD(req) case ExecShellDialectRaw: return buildExecCommandRaw(req) default: return "", fmt.Errorf("unsupported exec shell dialect %q", req.ShellDialect) } } func normalizeExecShellDialect(dialect ExecShellDialect) (ExecShellDialect, error) { if strings.TrimSpace(string(dialect)) == "" { return ExecShellDialectPOSIX, nil } switch ExecShellDialect(strings.ToLower(strings.TrimSpace(string(dialect)))) { case ExecShellDialectPOSIX, ExecShellDialectPowerShell, ExecShellDialectCMD, ExecShellDialectRaw: return ExecShellDialect(strings.ToLower(strings.TrimSpace(string(dialect)))), nil default: return "", fmt.Errorf("invalid exec shell dialect %q", dialect) } } func buildExecCommandPOSIX(req ExecRequest) (string, error) { parts := make([]string, 0, 3) if strings.TrimSpace(req.Dir) != "" { parts = append(parts, "cd "+shellSingleQuote(req.Dir)) } if len(req.Env) > 0 { keys := make([]string, 0, len(req.Env)) for key := range req.Env { if !isValidShellEnvKey(key) { return "", fmt.Errorf("invalid env key %q", key) } keys = append(keys, key) } sort.Strings(keys) assignments := make([]string, 0, len(keys)) for _, key := range keys { assignments = append(assignments, key+"="+shellSingleQuote(req.Env[key])) } parts = append(parts, "export "+strings.Join(assignments, " ")) } parts = append(parts, req.Command) return strings.Join(parts, " && "), nil } func buildExecCommandPowerShell(req ExecRequest) (string, error) { parts := make([]string, 0, 2+len(req.Env)) parts = append(parts, "$ErrorActionPreference = 'Stop'") if strings.TrimSpace(req.Dir) != "" { parts = append(parts, "Set-Location -LiteralPath "+powerShellSingleQuote(req.Dir)) } if len(req.Env) > 0 { keys := make([]string, 0, len(req.Env)) for key := range req.Env { if !isValidShellEnvKey(key) { return "", fmt.Errorf("invalid env key %q", key) } keys = append(keys, key) } sort.Strings(keys) for _, key := range keys { parts = append(parts, "$env:"+key+" = "+powerShellSingleQuote(req.Env[key])) } } parts = append(parts, req.Command) return strings.Join(parts, "; "), nil } func buildExecCommandCMD(req ExecRequest) (string, error) { keys := make([]string, 0, len(req.Env)) for key := range req.Env { if !isValidShellEnvKey(key) { return "", fmt.Errorf("invalid env key %q", key) } keys = append(keys, key) } sort.Strings(keys) replacements := make(map[string]string, len(keys)+1) if strings.TrimSpace(req.Dir) != "" { replacements["CD"] = "!CD!" } for _, key := range keys { replacements[strings.ToUpper(key)] = "!" + key + "!" } command, rewrotePercentVars := rewriteCMDPercentVariables(req.Command, replacements) needsDelayedExpansion := rewrotePercentVars if strings.TrimSpace(req.Dir) != "" && cmdContainsBangVariable(command, "CD") { needsDelayedExpansion = true } for _, key := range keys { if cmdContainsBangVariable(command, key) { needsDelayedExpansion = true } } parts := make([]string, 0, 3+len(keys)) if len(keys) > 0 { parts = append(parts, "setlocal DisableDelayedExpansion") for _, key := range keys { parts = append(parts, "set "+key+"="+cmdEscapeForSetValue(req.Env[key])) } } if strings.TrimSpace(req.Dir) != "" { parts = append(parts, "cd /d "+cmdEscapeForBareArgument(req.Dir, true)) } if needsDelayedExpansion { parts = append(parts, wrapCMDCommand(command)) } else { parts = append(parts, command) } return strings.Join(parts, " && "), nil } func buildExecCommandRaw(req ExecRequest) (string, error) { if strings.TrimSpace(req.Dir) != "" { return "", errors.New("raw exec shell dialect does not support Dir") } if len(req.Env) > 0 { return "", errors.New("raw exec shell dialect does not support Env") } return req.Command, nil } func combineCommandOutput(stdout string, stderr string) []byte { if stdout == "" && stderr == "" { return nil } if stdout == "" { return []byte(stderr) } if stderr == "" { return []byte(stdout) } return []byte(stdout + "\n" + stderr) } func powerShellSingleQuote(s string) string { return "'" + strings.ReplaceAll(s, "'", "''") + "'" } func wrapCMDCommand(script string) string { return `cmd.exe /Q /D /V:ON /C ` + cmdEscapeForNestedCommand(script) } func rewriteCMDPercentVariables(command string, replacements map[string]string) (string, bool) { if len(replacements) == 0 || command == "" { return command, false } var builder strings.Builder builder.Grow(len(command)) rewrote := false for i := 0; i < len(command); { if command[i] != '%' { builder.WriteByte(command[i]) i++ continue } end := strings.IndexByte(command[i+1:], '%') if end < 0 { builder.WriteByte(command[i]) i++ continue } end += i + 1 name := command[i+1 : end] if name == "" { builder.WriteString(command[i : end+1]) i = end + 1 continue } replacement, ok := replacements[strings.ToUpper(name)] if !ok { builder.WriteString(command[i : end+1]) i = end + 1 continue } builder.WriteString(replacement) rewrote = true i = end + 1 } return builder.String(), rewrote } func cmdContainsBangVariable(command string, name string) bool { if command == "" || name == "" { return false } return strings.Contains(strings.ToUpper(command), "!"+strings.ToUpper(name)+"!") } func cmdEscapeForSetValue(value string) string { var builder strings.Builder builder.Grow(len(value)) for _, char := range value { switch char { case '^': builder.WriteString("^^") case '&', '|', '<', '>', '(', ')', '"': builder.WriteByte('^') builder.WriteRune(char) case '%': builder.WriteString("%%") default: builder.WriteRune(char) } } return builder.String() } func cmdEscapeForBareArgument(value string, escapeSpace bool) string { var builder strings.Builder builder.Grow(len(value)) for _, char := range value { switch char { case '^': builder.WriteString("^^") case '&', '|', '<', '>', '(', ')', '"': builder.WriteByte('^') builder.WriteRune(char) case '%': builder.WriteString("%%") case ' ': if escapeSpace { builder.WriteString("^ ") } else { builder.WriteRune(char) } default: builder.WriteRune(char) } } return builder.String() } func cmdEscapeForNestedCommand(command string) string { return cmdEscapeForBareArgument(command, false) } type captureBuffer struct { limit int discard bool buffer bytes.Buffer truncated bool } func newCaptureBuffer(limit int, discard bool) *captureBuffer { return &captureBuffer{ limit: limit, discard: discard, } } func (b *captureBuffer) Write(data []byte) (int, error) { if len(data) == 0 { return 0, nil } if b == nil { return len(data), nil } if b.discard { return len(data), nil } if b.limit <= 0 { _, _ = b.buffer.Write(data) return len(data), nil } remaining := b.limit - b.buffer.Len() if remaining <= 0 { b.truncated = true return len(data), nil } if len(data) > remaining { _, _ = b.buffer.Write(data[:remaining]) b.truncated = true return len(data), nil } _, _ = b.buffer.Write(data) return len(data), nil } func (b *captureBuffer) Bytes() []byte { if b == nil { return nil } return b.buffer.Bytes() } func (b *captureBuffer) Truncated() bool { if b == nil { return false } return b.truncated } func isValidShellEnvKey(key string) bool { if key == "" { return false } for i, r := range key { if i == 0 { if r != '_' && !unicode.IsLetter(r) { return false } continue } if r != '_' && !unicode.IsLetter(r) && !unicode.IsDigit(r) { return false } } return true } func streamExecReader(reader io.Reader, isStderr bool, chunks chan<- ExecStreamChunk, errCh chan<- error, wg *sync.WaitGroup) { defer wg.Done() buf := make([]byte, 4096) for { n, err := reader.Read(buf) if n > 0 { chunk := make([]byte, n) copy(chunk, buf[:n]) chunks <- ExecStreamChunk{ Data: chunk, Stderr: isStderr, } } if err == io.EOF { return } if err != nil { errCh <- err return } } } type execChunkDispatcher struct { onChunk func(ExecStreamChunk) done chan struct{} mu sync.Mutex cond *sync.Cond queue []ExecStreamChunk queueBytes int maxChunks int maxBytes int strategy ExecStreamOverflowStrategy closed bool stopped bool failed bool droppedBytes int droppedCount int } type execChunkDispatchStats struct { droppedChunks int droppedBytes int err error } func newExecChunkDispatcher(req ExecRequest, onChunk func(ExecStreamChunk)) (*execChunkDispatcher, error) { if onChunk == nil { return nil, nil } config, err := normalizeExecChunkDispatchConfig(req) if err != nil { return nil, err } dispatcher := &execChunkDispatcher{ onChunk: onChunk, done: make(chan struct{}), maxChunks: config.maxChunks, maxBytes: config.maxBytes, strategy: config.strategy, } dispatcher.cond = sync.NewCond(&dispatcher.mu) go dispatcher.run() return dispatcher, nil } func (d *execChunkDispatcher) Enqueue(chunk ExecStreamChunk) { if d == nil || len(chunk.Data) == 0 { return } d.mu.Lock() defer d.mu.Unlock() if d.closed || d.stopped { d.recordDropLocked(chunk) return } switch d.strategy { case ExecStreamOverflowDropOldest: for len(d.queue) > 0 && d.wouldOverflowLocked(chunk) { d.recordDropLocked(d.popOldestLocked()) } if d.wouldOverflowLocked(chunk) { d.recordDropLocked(chunk) return } case ExecStreamOverflowDropNewest: if d.wouldOverflowLocked(chunk) { d.recordDropLocked(chunk) return } case ExecStreamOverflowFail: if d.wouldOverflowLocked(chunk) { d.recordDropLocked(chunk) d.stopWithOverflowLocked() return } } d.queue = append(d.queue, chunk) d.queueBytes += len(chunk.Data) d.cond.Signal() } func (d *execChunkDispatcher) Close() { if d == nil { return } d.mu.Lock() if d.closed { d.mu.Unlock() return } d.closed = true d.cond.Broadcast() d.mu.Unlock() } func (d *execChunkDispatcher) run() { defer close(d.done) for { chunk, ok := d.next() if !ok { return } d.onChunk(chunk) } } func (d *execChunkDispatcher) next() (ExecStreamChunk, bool) { d.mu.Lock() defer d.mu.Unlock() for len(d.queue) == 0 && !d.closed && !d.stopped { d.cond.Wait() } if len(d.queue) == 0 { return ExecStreamChunk{}, false } chunk := d.popOldestLocked() return chunk, true } func (d *execChunkDispatcher) Wait() execChunkDispatchStats { if d == nil { return execChunkDispatchStats{} } <-d.done d.mu.Lock() defer d.mu.Unlock() return execChunkDispatchStats{ droppedChunks: d.droppedCount, droppedBytes: d.droppedBytes, err: d.dispatchErrorLocked(), } } func (d *execChunkDispatcher) wouldOverflowLocked(chunk ExecStreamChunk) bool { if d.maxChunks > 0 && len(d.queue)+1 > d.maxChunks { return true } if d.maxBytes > 0 && d.queueBytes+len(chunk.Data) > d.maxBytes { return true } return false } func (d *execChunkDispatcher) popOldestLocked() ExecStreamChunk { if len(d.queue) == 0 { return ExecStreamChunk{} } chunk := d.queue[0] d.queue[0] = ExecStreamChunk{} d.queue = d.queue[1:] d.queueBytes -= len(chunk.Data) if d.queueBytes < 0 { d.queueBytes = 0 } if len(d.queue) == 0 { d.queue = nil } return chunk } func (d *execChunkDispatcher) recordDropLocked(chunk ExecStreamChunk) { if len(chunk.Data) == 0 { return } d.droppedCount++ d.droppedBytes += len(chunk.Data) } func (d *execChunkDispatcher) stopWithOverflowLocked() { if d.stopped { return } for len(d.queue) > 0 { d.recordDropLocked(d.popOldestLocked()) } d.stopped = true d.failed = true d.cond.Broadcast() } func (d *execChunkDispatcher) dispatchErrorLocked() error { if !d.failed { return nil } return &ExecStreamOverflowError{ DroppedChunks: d.droppedCount, DroppedBytes: d.droppedBytes, } } type execChunkDispatchConfig struct { maxChunks int maxBytes int strategy ExecStreamOverflowStrategy } func normalizeExecChunkDispatchConfig(req ExecRequest) (execChunkDispatchConfig, error) { config := execChunkDispatchConfig{ maxChunks: req.StreamMaxPendingChunks, maxBytes: req.StreamMaxPendingBytes, strategy: req.StreamOverflowStrategy, } if config.maxChunks <= 0 { config.maxChunks = defaultExecStreamMaxPendingChunks } if config.maxBytes <= 0 { config.maxBytes = defaultExecStreamMaxPendingBytes } if config.strategy == "" { config.strategy = ExecStreamOverflowDropOldest } switch config.strategy { case ExecStreamOverflowDropOldest, ExecStreamOverflowDropNewest, ExecStreamOverflowFail: return config, nil default: return execChunkDispatchConfig{}, fmt.Errorf("invalid exec stream overflow strategy %q", req.StreamOverflowStrategy) } } func firstExecError(errCh <-chan error) error { for err := range errCh { if err != nil { return err } } return nil } func (s *StarSSH) ShellOne(cmd string) (string, error) { result, err := s.Exec(context.Background(), ExecRequest{ Command: cmd, }) if err != nil { return "", err } combined := strings.TrimSpace(result.CombinedString()) if cmdErr := result.CommandError(); cmdErr != nil { return combined, cmdErr } return combined, nil } func (s *StarSSH) ShellOneShowScreen(cmd string) (string, error) { return s.streamCommand(cmd, func(chunk string) { fmt.Print(chunk) }) } func (s *StarSSH) ShellOneShowScreenResult(cmd string) (*ExecResult, error) { return s.streamCommandResult(cmd, func(chunk string) { fmt.Print(chunk) }) } func (s *StarSSH) ShellOneToFunc(cmd string, callback func(string)) (string, error) { return s.streamCommand(cmd, callback) } func (s *StarSSH) ShellOneToFuncResult(cmd string, callback func(string)) (*ExecResult, error) { return s.streamCommandResult(cmd, callback) } func (s *StarSSH) streamCommand(cmd string, onChunk func(string)) (string, error) { result, err := s.streamCommandResult(cmd, onChunk) stdoutText := strings.TrimSpace(resultStdoutString(result)) if err != nil { return stdoutText, err } return stdoutText, streamCommandLegacyError(result) } func (s *StarSSH) streamCommandResult(cmd string, onChunk func(string)) (*ExecResult, error) { result, err := s.ExecStream(context.Background(), ExecRequest{ Command: cmd, }, func(chunk ExecStreamChunk) { if onChunk != nil { onChunk(string(chunk.Data)) } }) return result, err } func streamCommandLegacyError(result *ExecResult) error { if result == nil { return nil } return errors.Join(result.CommandError(), result.StreamOverflowError()) } func resultStdoutString(result *ExecResult) string { if result == nil { return "" } return result.StdoutString() } func (s *StarSSH) Exists(filepath string) bool { return s.remotePathProbe(legacyPathProbeExists, filepath) } func (s *StarSSH) IsFile(filepath string) bool { return s.remotePathProbe(legacyPathProbeFile, filepath) } func (s *StarSSH) IsFolder(filepath string) bool { return s.remotePathProbe(legacyPathProbeDirectory, filepath) } func (s *StarSSH) GetUid() string { return s.remoteIdentityProbe(legacyIdentityProbeUID) } func (s *StarSSH) GetGid() string { return s.remoteIdentityProbe(legacyIdentityProbeGID) } func (s *StarSSH) GetUser() string { return s.remoteIdentityProbe(legacyIdentityProbeUser) } func (s *StarSSH) GetGroup() string { return s.remoteIdentityProbe(legacyIdentityProbeGroup) } type legacyPathProbeKind string const ( legacyPathProbeExists legacyPathProbeKind = "exists" legacyPathProbeFile legacyPathProbeKind = "file" legacyPathProbeDirectory legacyPathProbeKind = "directory" ) type legacyIdentityProbeKind string const ( legacyIdentityProbeUID legacyIdentityProbeKind = "uid" legacyIdentityProbeGID legacyIdentityProbeKind = "gid" legacyIdentityProbeUser legacyIdentityProbeKind = "user" legacyIdentityProbeGroup legacyIdentityProbeKind = "group" ) func (s *StarSSH) remotePathProbe(kind legacyPathProbeKind, filepath string) bool { result, ok := s.tryLegacyProbeRequests(buildLegacyPathProbeRequests(kind, filepath)) return ok && strings.TrimSpace(result) == "1" } func (s *StarSSH) remoteIdentityProbe(kind legacyIdentityProbeKind) string { result, ok := s.tryLegacyProbeRequests(buildLegacyIdentityProbeRequests(kind)) if !ok { return "" } return strings.TrimSpace(result) } func (s *StarSSH) tryLegacyProbeRequests(requests []ExecRequest) (string, bool) { if s == nil { return "", false } for _, req := range requests { result, err := execRequestRunner(s, context.Background(), req) if err != nil || result == nil || result.CommandError() != nil { continue } return result.StdoutString(), true } return "", false } func buildLegacyPathProbeRequests(kind legacyPathProbeKind, filepath string) []ExecRequest { requests := []ExecRequest{ buildLegacyRawExecRequest(wrapPOSIXRawCommand(buildLegacyPOSIXPathProbeScript(kind, filepath))), } for _, executable := range []string{"powershell.exe", "pwsh.exe", "pwsh"} { requests = append(requests, buildLegacyRawExecRequest( wrapPowerShellRawCommand(executable, buildLegacyPowerShellPathProbeScript(kind, filepath)), )) } requests = append(requests, buildLegacyRawExecRequest(buildLegacyCMDPathProbeCommand(kind, filepath))) return requests } func buildLegacyIdentityProbeRequests(kind legacyIdentityProbeKind) []ExecRequest { requests := []ExecRequest{ buildLegacyRawExecRequest(wrapPOSIXRawCommand(buildLegacyPOSIXIdentityProbeScript(kind))), } for _, executable := range []string{"powershell.exe", "pwsh.exe", "pwsh"} { requests = append(requests, buildLegacyRawExecRequest( wrapPowerShellRawCommand(executable, buildLegacyPowerShellIdentityProbeScript(kind)), )) } if kind == legacyIdentityProbeUser { requests = append(requests, buildLegacyRawExecRequest("cmd.exe /Q /D /C whoami")) } return requests } func buildLegacyRawExecRequest(command string) ExecRequest { return ExecRequest{ Command: command, ShellDialect: ExecShellDialectRaw, } } func wrapPOSIXRawCommand(script string) string { return "sh -lc " + shellSingleQuote(script) } func wrapPowerShellRawCommand(executable string, script string) string { return executable + " -NoLogo -NoProfile -NonInteractive -Command " + powerShellSingleQuote(script) } func buildLegacyPOSIXPathProbeScript(kind legacyPathProbeKind, filepath string) string { flag := "-e" switch kind { case legacyPathProbeFile: flag = "-f" case legacyPathProbeDirectory: flag = "-d" } return fmt.Sprintf("if [ %s -- %s ]; then printf '1\\n'; else printf '0\\n'; fi", flag, shellSingleQuote(filepath)) } func buildLegacyPowerShellPathProbeScript(kind legacyPathProbeKind, filepath string) string { condition := "Test-Path -LiteralPath " + powerShellSingleQuote(filepath) switch kind { case legacyPathProbeFile: condition += " -PathType Leaf" case legacyPathProbeDirectory: condition += " -PathType Container" } return "if (" + condition + ") { Write-Output '1' } else { Write-Output '0' }" } func buildLegacyCMDPathProbeCommand(kind legacyPathProbeKind, filepath string) string { parts := []string{ "setlocal DisableDelayedExpansion", `set "STARSSH_PATH=` + cmdEscapeForSetValue(filepath) + `"`, } switch kind { case legacyPathProbeExists: parts = append(parts, wrapCMDCommand(`if exist "!STARSSH_PATH!" (echo 1) else echo 0`)) case legacyPathProbeFile: parts = append(parts, wrapCMDCommand(`if exist "!STARSSH_PATH!" (if exist "!STARSSH_PATH!\NUL" (echo 0) else echo 1) else echo 0`)) case legacyPathProbeDirectory: parts = append(parts, wrapCMDCommand(`if exist "!STARSSH_PATH!\NUL" (echo 1) else echo 0`)) } return strings.Join(parts, " && ") } func buildLegacyPOSIXIdentityProbeScript(kind legacyIdentityProbeKind) string { switch kind { case legacyIdentityProbeUID: return "id -u" case legacyIdentityProbeGID: return "id -g" case legacyIdentityProbeUser: return "id -un" case legacyIdentityProbeGroup: return "id -gn" default: return "" } } func buildLegacyPowerShellIdentityProbeScript(kind legacyIdentityProbeKind) string { switch kind { case legacyIdentityProbeUID: return "[System.Security.Principal.WindowsIdentity]::GetCurrent().User.Value" case legacyIdentityProbeGID: return "$id = [System.Security.Principal.WindowsIdentity]::GetCurrent(); $group = $id.Groups | Select-Object -First 1; if ($group) { $group.Value }" case legacyIdentityProbeUser: return "$env:USERNAME" case legacyIdentityProbeGroup: return "$id = [System.Security.Principal.WindowsIdentity]::GetCurrent(); $group = $id.Groups | Select-Object -First 1; if ($group) { try { $group.Translate([System.Security.Principal.NTAccount]).Value } catch { $group.Value } }" default: return "" } }