starssh/shell.go

593 lines
12 KiB
Go
Raw Permalink Normal View History

refactor: 重构 starssh 核心运行时并补强 ssh/exec/terminal/sftp 能力 - 拆分原有单体 ssh.go,按职责重组为 types、utils、transport、login、keepalive、session、exec、pool、shell、terminal、forward、hostkey、state 等模块,并补充平台相关实现 - 重做登录与连接运行时,补齐基于 context 的建连、jump/proxy 链路、可配置认证顺序,以及 Unix/Windows 下的 ssh-agent 支持 - 新增正式非交互执行模型 ExecRequest/ExecResult,支持流式输出、溢出统计、超时控制,以及 posix/powershell/cmd/raw 多方言执行 - 保留旧 shell 风格兼容接口,同时让路径/用户探测等 helper 具备跨 shell fallback,避免 Windows 目标继续硬依赖 POSIX 命令 - 新增 TerminalSession 作为原始交互终端基座,提供 IO attach、resize、signal/control、退出状态与关闭原因管理 - 重构端口转发语义,默认复用当前 SSH 连接,并显式提供 detached 的本地/动态转发模式承载隔离场景 - 梳理 keepalive 与取消语义,区分仅取消本次操作和关闭整条连接,并统一连接状态与传输关闭路径 - 围绕新的 session/连接生命周期重做执行池与运行时支撑 - 大幅增强 SFTP 传输链路,补齐更安全的原子替换、校验、进度回调、重试隔离、可复用 client 生命周期与失败语义 - 新增取消语义、keepalive、SFTP、forward、terminal input 等关键回归测试,提升核心链路稳定性
2026-04-26 10:45:39 +08:00
package starssh
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"strings"
"time"
)
var (
	// errStarShellPOSIXOnly is the sentinel wrapped into every prompt-bootstrap
	// failure, signalling that the remote shell could not be driven by the
	// legacy POSIX helper.
	errStarShellPOSIXOnly = errors.New("legacy StarShell only supports POSIX-compatible shells")
)
type ShellRequest struct {
Command string
Timeout time.Duration
Keyword string
UseWaitDefault *bool
}
// NewShell creates the legacy prompt-driven POSIX shell helper.
//
// It opens a PTY session, wires up stdin/stdout/stderr pipes, starts the
// remote shell (Session.Shell), launches the session watcher and the output
// pumps, and then rewrites the prompt to the shell's prompt token. If any step
// fails, the PTY session is closed and the error is returned; prompt setup
// failures are reported wrapped in errStarShellPOSIXOnly.
//
// For raw interactive terminal flows, prefer NewTerminal.
func (s *StarSSH) NewShell() (shell *StarShell, err error) {
	shell = &StarShell{
		UseWaitDefault: true,
		WaitTimeout:    defaultShellWaitTimeout,
		isecho:         true,
		iscolor:        true,
		promptToken:    defaultShellPromptToken,
	}
	if shell.Session, err = s.NewPTYSession(nil); err != nil {
		return nil, err
	}

	// abort releases the PTY session before surfacing a setup failure.
	abort := func(cause error) (*StarShell, error) {
		_ = shell.Session.Close()
		return nil, cause
	}

	if shell.in, err = shell.Session.StdinPipe(); err != nil {
		return abort(err)
	}
	stdout, err := shell.Session.StdoutPipe()
	if err != nil {
		return abort(err)
	}
	shell.out = bufio.NewReader(stdout)
	stderr, err := shell.Session.StderrPipe()
	if err != nil {
		return abort(err)
	}
	shell.er = bufio.NewReader(stderr)

	if err = shell.Session.Shell(); err != nil {
		return abort(err)
	}
	go shell.watchSession()
	shell.gohub()

	if err = shell.configurePrompt(context.Background()); err != nil {
		return abort(err)
	}
	shell.Clear()
	return shell, nil
}
// configurePrompt bootstraps the legacy prompt on a freshly started shell.
//
// It unsets PROMPT_COMMAND and sets PS1/PROMPT to the shell's prompt token.
// It then writes a printf probe and polls the captured output until the probe
// token appears on stdout.
//
// Failures are reported wrapped in errStarShellPOSIXOnly:
//   - a write error,
//   - a recorded stream/session error,
//   - stderr text that looks like a non-POSIX shell rejecting the command,
//   - the setup deadline expiring.
//
// Cancellation of ctx (rather than its deadline expiring) is returned as
// ctx.Err(). When ctx has no deadline, defaultShellSetupTimeout is applied.
func (s *StarShell) configurePrompt(ctx context.Context) error {
	if s == nil {
		return errors.New("shell is nil")
	}
	if ctx == nil {
		ctx = context.Background()
	}
	if _, hasDeadline := ctx.Deadline(); !hasDeadline {
		timeoutCtx, cancel := context.WithTimeout(ctx, defaultShellSetupTimeout)
		defer cancel()
		ctx = timeoutCtx
	}
	prompt := s.promptToken + " "
	setupCommands := []string{
		"unset PROMPT_COMMAND >/dev/null 2>&1 || true",
		fmt.Sprintf("export PS1=%s PS2='' PROMPT=%s RPROMPT='' >/dev/null 2>&1 || true", shellSingleQuote(prompt), shellSingleQuote(prompt)),
	}
	s.Clear()
	for _, cmd := range setupCommands {
		if err := s.WriteCommand(cmd); err != nil {
			return fmt.Errorf("%w: %v", errStarShellPOSIXOnly, err)
		}
	}

	// The probe token is printed from two separately quoted halves. The PTY
	// echoes every command line back on stdout; if the command contained the
	// token verbatim, that echo alone would satisfy the check below even when
	// the shell never executed printf (e.g. cmd.exe or PowerShell). Split this
	// way, only real printf output contains the contiguous token.
	probeHead := "__STARSSH_POSIX_"
	probeTail := "READY__" + newNonce(6)
	probeToken := probeHead + probeTail
	probeCmd := fmt.Sprintf("printf '%%s%%s\\n' %s %s", shellSingleQuote(probeHead), shellSingleQuote(probeTail))
	if err := s.WriteCommand(probeCmd); err != nil {
		return fmt.Errorf("%w: %v", errStarShellPOSIXOnly, err)
	}

	ticker := time.NewTicker(defaultShellPollInterval)
	defer ticker.Stop()
	for {
		outRaw, errRaw, runErr := s.readState()
		if runErr != nil {
			return fmt.Errorf("%w: %v", errStarShellPOSIXOnly, runErr)
		}
		outs := normalizeShellOutput(stripControlSequences(string(outRaw)))
		if strings.Contains(outs, probeToken) {
			s.Clear()
			return nil
		}
		errs := normalizeShellOutput(stripControlSequences(string(errRaw)))
		if looksLikeNonPOSIXShellError(errs) {
			return fmt.Errorf("%w: %s", errStarShellPOSIXOnly, errs)
		}
		select {
		case <-ctx.Done():
			if errors.Is(ctx.Err(), context.DeadlineExceeded) {
				return fmt.Errorf("%w: prompt bootstrap timed out", errStarShellPOSIXOnly)
			}
			return ctx.Err()
		case <-ticker.C:
		}
	}
}
// watchSession blocks until the remote session ends. Any error returned by
// Session.Wait other than io.EOF is recorded as the shell's sticky error, so
// later reads and writes report it. A nil result from Wait records nothing.
func (s *StarShell) watchSession() {
	if s == nil || s.Session == nil {
		return
	}
	waitErr := s.Session.Wait()
	if waitErr == nil || errors.Is(waitErr, io.EOF) {
		return
	}
	s.setError(waitErr)
}
// Close closes the underlying session. Only the first call does any work, and
// only that call can return an error; later calls return nil. Calling Close on
// a nil shell is a no-op.
func (s *StarShell) Close() error {
	if s == nil {
		return nil
	}
	var err error
	s.closeOnce.Do(func() {
		if s.Session != nil {
			err = s.Session.Close()
		}
	})
	return err
}
// SwitchNoColor controls color stripping: when run is true, TrimColor (and the
// line callback) pass output through SedColor. Despite the name, true enables
// stripping.
func (s *StarShell) SwitchNoColor(run bool) {
	s.rw.Lock()
	s.iscolor = run
	s.rw.Unlock()
}
// SwitchEcho tells the helper whether the remote side echoes commands. When
// true, Shell removes lines matching the sent command from the returned stdout.
func (s *StarShell) SwitchEcho(run bool) {
	s.rw.Lock()
	s.isecho = run
	s.rw.Unlock()
}
// TrimColor returns str passed through SedColor when color stripping is
// enabled (see SwitchNoColor), or str unchanged otherwise.
func (s *StarShell) TrimColor(str string) string {
	s.rw.RLock()
	strip := s.iscolor
	s.rw.RUnlock()
	if !strip {
		return str
	}
	return SedColor(str)
}
// SwitchPrint controls whether output from the remote shell is also written to
// the local standard output as it arrives [true|false].
func (s *StarShell) SwitchPrint(run bool) {
	s.rw.Lock()
	s.isprint = run
	s.rw.Unlock()
}
// SwitchFunc controls whether each line of remote shell output is handed
// immediately to the callback registered with SetFunc [true|false].
func (s *StarShell) SwitchFunc(run bool) {
	s.rw.Lock()
	s.isfuncs = run
	s.rw.Unlock()
}
// SetFunc registers the per-line output callback used while SwitchFunc is on.
// Each line is delivered on its own goroutine, so invocations may run
// concurrently and are not guaranteed to arrive in output order.
func (s *StarShell) SetFunc(funcs func(string)) {
	s.rw.Lock()
	s.funcs = funcs
	s.rw.Unlock()
}
// Clear discards everything captured so far from stdout and stderr. A recorded
// stream/session error is left untouched.
func (s *StarShell) Clear() {
	s.rw.Lock()
	s.outbyte, s.errbyte = []byte{}, []byte{}
	s.rw.Unlock()
}
// ShellClear behaves like Shell, but empties the capture buffers both before
// sending cmd and after the result has been collected.
func (s *StarShell) ShellClear(cmd string, sleep int) (string, string, error) {
	defer s.Clear()
	s.Clear()
	return s.Shell(cmd, sleep)
}
// Shell writes cmd to the shell, waits sleep milliseconds, and returns what has
// been captured on stdout and stderr by then. It does not wait for the command
// to finish; prefer Run or ShellWait for completion-aware execution.
//
// Both streams are trimmed and color-stripped per SwitchNoColor. When echo
// handling is on (SwitchEcho), lines matching cmd are removed from stdout.
func (s *StarShell) Shell(cmd string, sleep int) (string, string, error) {
	s.commandMu.Lock()
	defer s.commandMu.Unlock()

	if err := s.WriteCommand(cmd); err != nil {
		return "", "", err
	}
	stdoutRaw, stderrRaw, err := s.GetResult(sleep)
	if err != nil {
		return "", "", err
	}

	stdout := s.TrimColor(strings.TrimSpace(string(stdoutRaw)))
	stderr := s.TrimColor(strings.TrimSpace(string(stderrRaw)))

	s.rw.RLock()
	stripEcho := s.isecho
	s.rw.RUnlock()
	if stripEcho {
		stdout = stripCommandEchoFromOutput(stdout, cmd)
	}
	return strings.TrimSpace(stdout), stderr, nil
}
// ShellWait runs cmd via Run with the shell's default wait settings and a
// background context, returning trimmed stdout and stderr.
//
// The third value is the Run error if Run itself failed (both strings are then
// empty); otherwise it is result.CommandError().
func (s *StarShell) ShellWait(cmd string) (string, string, error) {
	res, err := s.RunString(context.Background(), cmd)
	if err != nil {
		return "", "", err
	}
	stdout := strings.TrimSpace(res.StdoutString())
	stderr := strings.TrimSpace(res.StderrString())
	return stdout, stderr, res.CommandError()
}
// RunString is shorthand for Run with a request that only sets Command, so all
// wait settings come from the shell.
func (s *StarShell) RunString(ctx context.Context, command string) (*ExecResult, error) {
	req := ShellRequest{Command: command}
	return s.Run(ctx, req)
}
// Run executes req.Command in the interactive shell and waits for it to finish.
//
// Completion is detected in one of two ways. Each is taken from req when set,
// otherwise from the shell's own settings:
//   - marker mode (UseWaitDefault): the command is bracketed by a printf'd
//     begin token and an end marker that carries the command's exit status
//     ($?). Only the bracketed stdout is returned, and ExitCode is set from
//     the marker.
//   - keyword mode (Keyword): the call returns as soon as the keyword appears
//     in stdout or stderr. The whole captured output is returned, and the exit
//     code is not parsed.
//
// When both modes are enabled, the marker is checked first on each poll.
//
// The wait is bounded by ctx. If ctx has no deadline, the first positive value
// of req.Timeout, the shell's WaitTimeout, or defaultShellWaitTimeout is
// applied. Calls are serialized per shell, and the capture buffers are cleared
// before the command is sent and again on return.
func (s *StarShell) Run(ctx context.Context, req ShellRequest) (*ExecResult, error) {
	if s == nil {
		return nil, errors.New("shell is nil")
	}
	if ctx == nil {
		ctx = context.Background()
	}
	s.commandMu.Lock()
	defer s.commandMu.Unlock()
	if strings.TrimSpace(req.Command) == "" {
		return nil, errors.New("command is empty")
	}

	s.rw.RLock()
	useDefault := s.UseWaitDefault
	keyword := s.Keyword
	waitTimeout := s.WaitTimeout
	promptToken := s.promptToken
	s.rw.RUnlock()
	// Request-level settings override the shell defaults for this call only.
	// They are resolved once, here; the polling loop must not re-read the
	// shell fields, or these overrides would be silently discarded after the
	// first poll (e.g. a keyword-only request would never complete).
	if req.UseWaitDefault != nil {
		useDefault = *req.UseWaitDefault
	}
	if strings.TrimSpace(req.Keyword) != "" {
		keyword = req.Keyword
	}
	if req.Timeout > 0 {
		waitTimeout = req.Timeout
	}
	if !useDefault && keyword == "" {
		return nil, errors.New("ShellRun requires UseWaitDefault=true or Keyword set")
	}
	if waitTimeout <= 0 {
		waitTimeout = defaultShellWaitTimeout
	}
	if _, hasDeadline := ctx.Deadline(); !hasDeadline {
		timeoutCtx, cancel := context.WithTimeout(ctx, waitTimeout)
		defer cancel()
		ctx = timeoutCtx
	}

	startAt := time.Now()
	s.Clear()
	defer s.Clear()

	// Send: begin token, the command itself, then the end marker. The marker
	// captures $? right after the command so its exit status is reported.
	beginToken, endToken := newCommandTokens()
	markerCmd := fmt.Sprintf("__STARSSH_RC=$?; printf '%s:%%s\\n' \"$__STARSSH_RC\"", endToken)
	for _, line := range []string{fmt.Sprintf("printf '%s\\n'", beginToken), req.Command, markerCmd} {
		if err := s.WriteCommand(line); err != nil {
			return nil, err
		}
	}

	var (
		outc     string
		errc     string
		exitCode int
	)
	ticker := time.NewTicker(defaultShellPollInterval)
	defer ticker.Stop()
poll:
	for {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-ticker.C:
		}
		outRaw, errRaw, runErr := s.readState()
		if runErr != nil {
			return nil, runErr
		}
		outs := normalizeShellOutput(stripControlSequences(string(outRaw)))
		errs := normalizeShellOutput(stripControlSequences(string(errRaw)))
		if useDefault {
			segment, rc, found, parseErr := extractCommandSegment(outs, beginToken, endToken)
			if parseErr != nil {
				return nil, parseErr
			}
			if found {
				outc, errc, exitCode = segment, errs, rc
				break poll
			}
		}
		if keyword != "" && (strings.Contains(outs, keyword) || strings.Contains(errs, keyword)) {
			outc, errc = outs, errs
			break poll
		}
	}

	outc = collectLinesForCommandOutput(outc, promptToken, beginToken, endToken)
	errc = collectLinesForCommandOutput(errc, promptToken, beginToken, endToken)
	outc = stripCommandEchoFromOutput(outc, req.Command)
	stdoutText := strings.TrimSpace(outc)
	stderrText := strings.TrimSpace(errc)
	result := &ExecResult{
		Command:  req.Command,
		Stdout:   []byte(stdoutText),
		Stderr:   []byte(stderrText),
		Combined: combineCommandOutput(stdoutText, stderrText),
		Duration: time.Since(startAt),
	}
	if useDefault {
		result.ExitCode = exitCode
	}
	return result, nil
}
// extractCommandSegment locates the command output bracketed by Run's markers.
//
// It finds the first line of stdout that equals beginToken (ignoring
// surrounding whitespace). It then hands everything after that line to
// splitByEndToken, which presumably returns the text preceding the end marker
// and the exit code parsed from it (confirm against its definition).
//
// It returns the trimmed segment, the exit code, and found=true once both
// markers are present. While either marker is missing it returns
// found=false, and errors from splitByEndToken are passed through.
func extractCommandSegment(stdout string, beginToken string, endToken string) (string, int, bool, error) {
	lines := strings.Split(stdout, "\n")
	start := -1
	for idx := 0; idx < len(lines) && start < 0; idx++ {
		if strings.TrimSpace(lines[idx]) == beginToken {
			start = idx
		}
	}
	if start == -1 {
		return "", 0, false, nil
	}

	body, rc, ok, err := splitByEndToken(strings.Join(lines[start+1:], "\n"), endToken)
	switch {
	case err != nil:
		return "", 0, false, err
	case !ok:
		return "", 0, false, nil
	}
	return strings.TrimSpace(body), rc, true, nil
}
// GetResult waits sleep milliseconds (no wait when sleep <= 0) and then returns
// copies of the captured stdout and stderr plus any recorded stream error.
func (s *StarShell) GetResult(sleep int) ([]byte, []byte, error) {
	if sleep <= 0 {
		return s.readState()
	}
	time.Sleep(time.Duration(sleep) * time.Millisecond)
	return s.readState()
}
// WriteCommand sends cmd followed by a newline to the shell's stdin.
func (s *StarShell) WriteCommand(cmd string) error {
	line := cmd + "\n"
	return s.Write([]byte(line))
}
// Write sends raw bytes to the shell's stdin.
//
// Once a stream or session error has been recorded, Write refuses to write and
// returns that error instead. Concurrent writers are serialized by writeMu.
func (s *StarShell) Write(bstr []byte) error {
	if s == nil {
		return errors.New("shell is nil")
	}
	if s.in == nil {
		return errors.New("shell stdin is not initialized")
	}
	// Only the sticky error is needed here. Read it directly rather than
	// calling readState, which would copy the whole stdout/stderr capture
	// buffers on every write just to discard them.
	s.rw.RLock()
	stickyErr := s.errors
	s.rw.RUnlock()
	if stickyErr != nil {
		return stickyErr
	}
	s.writeMu.Lock()
	defer s.writeMu.Unlock()
	_, err := s.in.Write(bstr)
	return err
}
// gohub starts one streamPump goroutine per initialized output stream, with
// stderr first and then stdout.
func (s *StarShell) gohub() {
	streams := []struct {
		reader *bufio.Reader
		isErr  bool
	}{
		{reader: s.er, isErr: true},
		{reader: s.out, isErr: false},
	}
	for _, st := range streams {
		if st.reader == nil {
			continue
		}
		go s.streamPump(st.reader, st.isErr)
	}
}
// streamPump copies one session output stream into the shell's capture buffer
// until the stream ends.
//
// Data is consumed in chunks, so the buffer lock is taken and local stdout is
// written once per chunk rather than once per byte. For every chunk it:
//   - appends the bytes to errbyte (isStderr) or outbyte,
//   - echoes them to local stdout when printing is enabled (SwitchPrint),
//   - hands each completed line to the SetFunc callback when SwitchFunc is on.
//     Each line is trimmed, color-stripped when SwitchNoColor is on, and
//     delivered on its own goroutine.
//
// io.EOF ends the pump silently. Any other read error is recorded via
// setError. Bytes returned together with an error are processed first.
func (s *StarShell) streamPump(reader *bufio.Reader, isStderr bool) {
	const chunkSize = 4096
	chunk := make([]byte, chunkSize)
	var cache []byte // current partial line awaiting '\n'
	for {
		n, readErr := reader.Read(chunk)
		if n > 0 {
			data := chunk[:n]
			s.rw.Lock()
			// append copies data, so reusing chunk on the next Read is safe.
			if isStderr {
				s.errbyte = append(s.errbyte, data...)
			} else {
				s.outbyte = append(s.outbyte, data...)
			}
			printEnabled := s.isprint
			funcEnabled := s.isfuncs && s.funcs != nil
			lineHandler := s.funcs
			trimColor := s.iscolor
			s.rw.Unlock()

			if printEnabled {
				fmt.Print(string(data))
			}
			for _, b := range data {
				cache = append(cache, b)
				if b != '\n' {
					continue
				}
				if funcEnabled {
					line := strings.TrimSpace(string(cache))
					if trimColor {
						line = SedColor(line)
					}
					go lineHandler(line)
				}
				cache = cache[:0]
			}
		}
		if readErr == io.EOF {
			return
		}
		if readErr != nil {
			s.setError(readErr)
			return
		}
	}
}
// setError records err as the shell's sticky error. Only the first recorded
// error is kept; nil and io.EOF (including wrapped EOF) are ignored.
func (s *StarShell) setError(err error) {
	switch {
	case err == nil, errors.Is(err, io.EOF):
		return
	}
	s.rw.Lock()
	if s.errors == nil {
		s.errors = err
	}
	s.rw.Unlock()
}
// readState returns independent copies of the captured stdout and stderr,
// together with the sticky error (nil if none has been recorded).
func (s *StarShell) readState() ([]byte, []byte, error) {
	s.rw.RLock()
	defer s.rw.RUnlock()
	stdout := append(make([]byte, 0, len(s.outbyte)), s.outbyte...)
	stderr := append(make([]byte, 0, len(s.errbyte)), s.errbyte...)
	return stdout, stderr, s.errors
}
// stripCommandEchoFromOutput removes the terminal's echo of cmd from output.
//
// For each non-blank line of cmd, in order, the first remaining output line
// equal to it (both compared with surrounding whitespace trimmed) is dropped.
// The result is trimmed. If either argument is empty, output is only trimmed.
func stripCommandEchoFromOutput(output string, cmd string) string {
	if cmd == "" || output == "" {
		return strings.TrimSpace(output)
	}
	remaining := strings.Split(output, "\n")
	for _, echoed := range strings.Split(cmd, "\n") {
		want := strings.TrimSpace(echoed)
		if want == "" {
			continue
		}
		for idx := 0; idx < len(remaining); idx++ {
			if strings.TrimSpace(remaining[idx]) != want {
				continue
			}
			remaining = append(remaining[:idx], remaining[idx+1:]...)
			break
		}
	}
	return strings.TrimSpace(strings.Join(remaining, "\n"))
}
// looksLikeNonPOSIXShellError reports whether output (case-insensitively)
// contains a phrase typical of a shell rejecting an unknown command. Examples
// are cmd.exe's "is not recognized as an internal or external command",
// PowerShell's "The term ...", and "command not found". Blank output is never
// a match.
func looksLikeNonPOSIXShellError(output string) bool {
	if strings.TrimSpace(output) == "" {
		return false
	}
	switch lower := strings.ToLower(output); {
	case strings.Contains(lower, "is not recognized as an internal or external command"),
		strings.Contains(lower, "the term "),
		strings.Contains(lower, "command not found"),
		strings.Contains(lower, "unknown command"),
		strings.Contains(lower, "not found"),
		strings.Contains(lower, "not recognized"):
		return true
	}
	return false
}
// GetUid returns the output of `id -u` (the remote user's numeric ID), trimmed.
// Errors are discarded, so the result may be empty on failure.
func (s *StarShell) GetUid() string {
	uid, _, _ := s.ShellWait("id -u")
	return strings.TrimSpace(uid)
}
// GetGid returns the output of `id -g` (the remote user's primary group ID),
// trimmed. Errors are discarded, so the result may be empty on failure.
func (s *StarShell) GetGid() string {
	gid, _, _ := s.ShellWait("id -g")
	return strings.TrimSpace(gid)
}
// GetUser returns the output of `id -un` (the remote user name), trimmed.
// Errors are discarded, so the result may be empty on failure.
func (s *StarShell) GetUser() string {
	name, _, _ := s.ShellWait("id -un")
	return strings.TrimSpace(name)
}
// GetGroup returns the output of `id -gn` (the remote user's primary group
// name), trimmed. Errors are discarded, so the result may be empty on failure.
func (s *StarShell) GetGroup() string {
	group, _, _ := s.ShellWait("id -gn")
	return strings.TrimSpace(group)
}