package starssh import ( "bufio" "context" "errors" "fmt" "io" "strings" "time" ) var errStarShellPOSIXOnly = errors.New("legacy StarShell only supports POSIX-compatible shells") type ShellRequest struct { Command string Timeout time.Duration Keyword string UseWaitDefault *bool } // NewShell creates the legacy prompt-driven POSIX shell helper. // For raw interactive terminal flows, prefer NewTerminal. func (s *StarSSH) NewShell() (shell *StarShell, err error) { shell = &StarShell{ UseWaitDefault: true, WaitTimeout: defaultShellWaitTimeout, isecho: true, iscolor: true, promptToken: defaultShellPromptToken, } shell.Session, err = s.NewPTYSession(nil) if err != nil { return nil, err } shell.in, err = shell.Session.StdinPipe() if err != nil { _ = shell.Session.Close() return nil, err } stdout, err := shell.Session.StdoutPipe() if err != nil { _ = shell.Session.Close() return nil, err } shell.out = bufio.NewReader(stdout) stderr, err := shell.Session.StderrPipe() if err != nil { _ = shell.Session.Close() return nil, err } shell.er = bufio.NewReader(stderr) if err := shell.Session.Shell(); err != nil { _ = shell.Session.Close() return nil, err } go shell.watchSession() shell.gohub() if err := shell.configurePrompt(context.Background()); err != nil { _ = shell.Session.Close() return nil, err } shell.Clear() return shell, nil } func (s *StarShell) configurePrompt(ctx context.Context) error { if s == nil { return errors.New("shell is nil") } if ctx == nil { ctx = context.Background() } if _, hasDeadline := ctx.Deadline(); !hasDeadline { timeoutCtx, cancel := context.WithTimeout(ctx, defaultShellSetupTimeout) defer cancel() ctx = timeoutCtx } prompt := s.promptToken + " " setupCommands := []string{ "unset PROMPT_COMMAND >/dev/null 2>&1 || true", fmt.Sprintf("export PS1=%s PS2='' PROMPT=%s RPROMPT='' >/dev/null 2>&1 || true", shellSingleQuote(prompt), shellSingleQuote(prompt)), } s.Clear() for _, cmd := range setupCommands { if err := s.WriteCommand(cmd); err != nil { return fmt.Errorf("%w: %v", errStarShellPOSIXOnly, err) } } probeToken := "__STARSSH_POSIX_READY__" + newNonce(6) if err := s.WriteCommand(fmt.Sprintf("printf '%%s\\n' %s", shellSingleQuote(probeToken))); err != nil { return fmt.Errorf("%w: %v", errStarShellPOSIXOnly, err) } ticker := time.NewTicker(defaultShellPollInterval) defer ticker.Stop() for { outRaw, errRaw, runErr := s.readState() if runErr != nil { return fmt.Errorf("%w: %v", errStarShellPOSIXOnly, runErr) } outs := normalizeShellOutput(stripControlSequences(string(outRaw))) if strings.Contains(outs, probeToken) { s.Clear() return nil } errs := normalizeShellOutput(stripControlSequences(string(errRaw))) if looksLikeNonPOSIXShellError(errs) { return fmt.Errorf("%w: %s", errStarShellPOSIXOnly, errs) } select { case <-ctx.Done(): if errors.Is(ctx.Err(), context.DeadlineExceeded) { return fmt.Errorf("%w: prompt bootstrap timed out", errStarShellPOSIXOnly) } return ctx.Err() case <-ticker.C: } } } func (s *StarShell) watchSession() { if s == nil || s.Session == nil { return } if err := s.Session.Wait(); err != nil && !errors.Is(err, io.EOF) { s.setError(err) } } func (s *StarShell) Close() error { if s == nil { return nil } var closeErr error s.closeOnce.Do(func() { if s.Session == nil { return } closeErr = s.Session.Close() }) return closeErr } func (s *StarShell) SwitchNoColor(run bool) { s.rw.Lock() defer s.rw.Unlock() s.iscolor = run } func (s *StarShell) SwitchEcho(run bool) { s.rw.Lock() defer s.rw.Unlock() s.isecho = run } func (s *StarShell) TrimColor(str string) string { s.rw.RLock() shouldTrim := s.iscolor s.rw.RUnlock() if shouldTrim { return SedColor(str) } return str } /* 本函数控制是否在本地屏幕上打印远程Shell的输出内容[true|false] */ func (s *StarShell) SwitchPrint(run bool) { s.rw.Lock() defer s.rw.Unlock() s.isprint = run } /* 本函数控制是否立即处理远程Shell输出每一行内容[true|false] */ func (s *StarShell) SwitchFunc(run bool) { s.rw.Lock() defer s.rw.Unlock() s.isfuncs = run } func (s *StarShell) SetFunc(funcs func(string)) { s.rw.Lock() defer s.rw.Unlock() s.funcs = funcs } func (s *StarShell) Clear() { s.rw.Lock() defer s.rw.Unlock() s.outbyte = []byte{} s.errbyte = []byte{} } func (s *StarShell) ShellClear(cmd string, sleep int) (string, string, error) { s.Clear() defer s.Clear() return s.Shell(cmd, sleep) } func (s *StarShell) Shell(cmd string, sleep int) (string, string, error) { s.commandMu.Lock() defer s.commandMu.Unlock() if err := s.WriteCommand(cmd); err != nil { return "", "", err } outRaw, errRaw, runErr := s.GetResult(sleep) if runErr != nil { return "", "", runErr } outText := s.TrimColor(strings.TrimSpace(string(outRaw))) s.rw.RLock() echoEnabled := s.isecho s.rw.RUnlock() if echoEnabled { outText = stripCommandEchoFromOutput(outText, cmd) } return strings.TrimSpace(outText), s.TrimColor(strings.TrimSpace(string(errRaw))), nil } func (s *StarShell) ShellWait(cmd string) (string, string, error) { result, err := s.Run(context.Background(), ShellRequest{ Command: cmd, }) if err != nil { return "", "", err } return strings.TrimSpace(result.StdoutString()), strings.TrimSpace(result.StderrString()), result.CommandError() } func (s *StarShell) RunString(ctx context.Context, command string) (*ExecResult, error) { return s.Run(ctx, ShellRequest{ Command: command, }) } func (s *StarShell) Run(ctx context.Context, req ShellRequest) (*ExecResult, error) { if s == nil { return nil, errors.New("shell is nil") } if ctx == nil { ctx = context.Background() } s.commandMu.Lock() defer s.commandMu.Unlock() if strings.TrimSpace(req.Command) == "" { return nil, errors.New("command is empty") } s.rw.RLock() useDefault := s.UseWaitDefault keyword := s.Keyword waitTimeout := s.WaitTimeout promptToken := s.promptToken s.rw.RUnlock() if req.UseWaitDefault != nil { useDefault = *req.UseWaitDefault } if strings.TrimSpace(req.Keyword) != "" { keyword = req.Keyword } if req.Timeout > 0 { waitTimeout = req.Timeout } if !useDefault && keyword == "" { return nil, errors.New("ShellRun requires UseWaitDefault=true or Keyword set") } if waitTimeout <= 0 { waitTimeout = defaultShellWaitTimeout } if _, hasDeadline := ctx.Deadline(); !hasDeadline { timeoutCtx, cancel := context.WithTimeout(ctx, waitTimeout) defer cancel() ctx = timeoutCtx } startAt := time.Now() s.Clear() defer s.Clear() beginToken, endToken := newCommandTokens() markerCmd := fmt.Sprintf("__STARSSH_RC=$?; printf '%s:%%s\\n' \"$__STARSSH_RC\"", endToken) if err := s.WriteCommand(fmt.Sprintf("printf '%s\\n'", beginToken)); err != nil { return nil, err } if err := s.WriteCommand(req.Command); err != nil { return nil, err } if err := s.WriteCommand(markerCmd); err != nil { return nil, err } var ( outc string errc string exitCode int done bool ) for { select { case <-ctx.Done(): return nil, ctx.Err() case <-time.After(defaultShellPollInterval): } outRaw, errRaw, runErr := s.readState() if runErr != nil { return nil, runErr } outs := normalizeShellOutput(stripControlSequences(string(outRaw))) errs := normalizeShellOutput(stripControlSequences(string(errRaw))) s.rw.RLock() useDefault = s.UseWaitDefault keyword = s.Keyword s.rw.RUnlock() if useDefault { segment, rc, found, parseErr := extractCommandSegment(outs, beginToken, endToken) if parseErr != nil { return nil, parseErr } if found { outc = segment errc = errs exitCode = rc done = true break } } if keyword != "" { if strings.Contains(outs, keyword) || strings.Contains(errs, keyword) { outc = outs errc = errs done = true break } } } if !done { return nil, errors.New("failed to collect shell result") } outc = collectLinesForCommandOutput(outc, promptToken, beginToken, endToken) errc = collectLinesForCommandOutput(errc, promptToken, beginToken, endToken) outc = stripCommandEchoFromOutput(outc, req.Command) stdoutText := strings.TrimSpace(outc) stderrText := strings.TrimSpace(errc) result := &ExecResult{ Command: req.Command, Stdout: []byte(stdoutText), Stderr: []byte(stderrText), Combined: combineCommandOutput(stdoutText, stderrText), Duration: time.Since(startAt), } if useDefault { result.ExitCode = exitCode } return result, nil } func extractCommandSegment(stdout string, beginToken string, endToken string) (string, int, bool, error) { lines := strings.Split(stdout, "\n") beginLine := -1 for i, line := range lines { if strings.TrimSpace(line) == beginToken { beginLine = i break } } if beginLine < 0 { return "", 0, false, nil } segment := strings.Join(lines[beginLine+1:], "\n") before, rc, found, err := splitByEndToken(segment, endToken) if err != nil { return "", 0, false, err } if !found { return "", 0, false, nil } return strings.TrimSpace(before), rc, true, nil } func (s *StarShell) GetResult(sleep int) ([]byte, []byte, error) { if sleep > 0 { time.Sleep(time.Millisecond * time.Duration(sleep)) } return s.readState() } func (s *StarShell) WriteCommand(cmd string) error { return s.Write([]byte(cmd + "\n")) } func (s *StarShell) Write(bstr []byte) error { if s == nil { return errors.New("shell is nil") } if s.in == nil { return errors.New("shell stdin is not initialized") } if _, _, runErr := s.readState(); runErr != nil { return runErr } s.writeMu.Lock() defer s.writeMu.Unlock() _, err := s.in.Write(bstr) return err } func (s *StarShell) gohub() { if s.er != nil { go s.streamPump(s.er, true) } if s.out != nil { go s.streamPump(s.out, false) } } func (s *StarShell) streamPump(reader *bufio.Reader, isStderr bool) { var cache []byte for { read, err := reader.ReadByte() if err == io.EOF { return } if err != nil { s.setError(err) return } s.rw.Lock() if isStderr { s.errbyte = append(s.errbyte, read) } else { s.outbyte = append(s.outbyte, read) } printEnabled := s.isprint funcEnabled := s.isfuncs && s.funcs != nil lineHandler := s.funcs trimColor := s.iscolor s.rw.Unlock() if printEnabled { fmt.Print(string([]byte{read})) } cache = append(cache, read) if read == '\n' { if funcEnabled { line := strings.TrimSpace(string(cache)) if trimColor { line = SedColor(line) } go lineHandler(line) } cache = cache[:0] } } } func (s *StarShell) setError(err error) { if err == nil || errors.Is(err, io.EOF) { return } s.rw.Lock() defer s.rw.Unlock() if s.errors == nil { s.errors = err } } func (s *StarShell) readState() ([]byte, []byte, error) { s.rw.RLock() defer s.rw.RUnlock() outCopy := make([]byte, len(s.outbyte)) copy(outCopy, s.outbyte) errCopy := make([]byte, len(s.errbyte)) copy(errCopy, s.errbyte) if s.errors != nil { return outCopy, errCopy, s.errors } return outCopy, errCopy, nil } func stripCommandEchoFromOutput(output string, cmd string) string { if output == "" || cmd == "" { return strings.TrimSpace(output) } lines := strings.Split(output, "\n") cmdLines := strings.Split(cmd, "\n") for _, cmdLine := range cmdLines { trimmedCmd := strings.TrimSpace(cmdLine) if trimmedCmd == "" { continue } for i, line := range lines { if strings.TrimSpace(line) == trimmedCmd { lines = append(lines[:i], lines[i+1:]...) break } } } return strings.TrimSpace(strings.Join(lines, "\n")) } func looksLikeNonPOSIXShellError(output string) bool { if strings.TrimSpace(output) == "" { return false } lower := strings.ToLower(output) indicators := []string{ "is not recognized as an internal or external command", "the term ", "command not found", "unknown command", "not found", "not recognized", } for _, indicator := range indicators { if strings.Contains(lower, indicator) { return true } } return false } func (s *StarShell) GetUid() string { res, _, _ := s.ShellWait("id -u") return strings.TrimSpace(res) } func (s *StarShell) GetGid() string { res, _, _ := s.ShellWait("id -g") return strings.TrimSpace(res) } func (s *StarShell) GetUser() string { res, _, _ := s.ShellWait("id -un") return strings.TrimSpace(res) } func (s *StarShell) GetGroup() string { res, _, _ := s.ShellWait("id -gn") return strings.TrimSpace(res) }