432 lines
8.7 KiB
Go
432 lines
8.7 KiB
Go
|
|
package starssh
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"errors"
|
||
|
|
"io"
|
||
|
|
"sync"
|
||
|
|
|
||
|
|
"golang.org/x/crypto/ssh"
|
||
|
|
)
|
||
|
|
|
||
|
|
func (s *StarSSH) NewTerminal(config *TerminalConfig) (*TerminalSession, error) {
|
||
|
|
session, err := s.NewPTYSession(config)
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
stdin, err := session.StdinPipe()
|
||
|
|
if err != nil {
|
||
|
|
_ = session.Close()
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
stdout, err := session.StdoutPipe()
|
||
|
|
if err != nil {
|
||
|
|
_ = session.Close()
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
stderr, err := session.StderrPipe()
|
||
|
|
if err != nil {
|
||
|
|
_ = session.Close()
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
if err := session.Shell(); err != nil {
|
||
|
|
_ = session.Close()
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
return &TerminalSession{
|
||
|
|
Session: session,
|
||
|
|
stdin: stdin,
|
||
|
|
stdout: stdout,
|
||
|
|
stderr: stderr,
|
||
|
|
runDone: make(chan struct{}),
|
||
|
|
waitDone: make(chan struct{}),
|
||
|
|
}, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *TerminalSession) AttachIO(stdin io.Reader, stdout io.Writer, stderr io.Writer) {
|
||
|
|
if t == nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
t.attachMu.Lock()
|
||
|
|
defer t.attachMu.Unlock()
|
||
|
|
t.in = stdin
|
||
|
|
t.out = stdout
|
||
|
|
t.errOut = stderr
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *TerminalSession) StdinWriter() io.Writer {
|
||
|
|
if t == nil {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
return t.stdin
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *TerminalSession) StdoutReader() io.Reader {
|
||
|
|
if t == nil {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
return t.stdout
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *TerminalSession) StderrReader() io.Reader {
|
||
|
|
if t == nil {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
return t.stderr
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *TerminalSession) Write(data []byte) (int, error) {
|
||
|
|
if t == nil || t.stdin == nil {
|
||
|
|
return 0, errors.New("terminal stdin is not initialized")
|
||
|
|
}
|
||
|
|
return t.stdin.Write(data)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *TerminalSession) SendControl(control TerminalControl) error {
|
||
|
|
_, err := t.Write([]byte{byte(control)})
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *TerminalSession) Interrupt() error {
|
||
|
|
return t.SendControl(TerminalControlInterrupt)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *TerminalSession) Signal(sig ssh.Signal) error {
|
||
|
|
if t == nil || t.Session == nil {
|
||
|
|
return errors.New("terminal session is not initialized")
|
||
|
|
}
|
||
|
|
if sig == "" {
|
||
|
|
return errors.New("signal is empty")
|
||
|
|
}
|
||
|
|
return t.Session.Signal(sig)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *TerminalSession) Run(ctx context.Context) error {
|
||
|
|
if t == nil {
|
||
|
|
return errors.New("terminal session is nil")
|
||
|
|
}
|
||
|
|
if ctx == nil {
|
||
|
|
ctx = context.Background()
|
||
|
|
}
|
||
|
|
|
||
|
|
t.runOnce.Do(func() {
|
||
|
|
t.runErr = t.run(ctx)
|
||
|
|
close(t.runDone)
|
||
|
|
})
|
||
|
|
|
||
|
|
<-t.runDone
|
||
|
|
return t.runErr
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *TerminalSession) run(ctx context.Context) error {
|
||
|
|
if t.Session == nil {
|
||
|
|
return errors.New("terminal session is not initialized")
|
||
|
|
}
|
||
|
|
|
||
|
|
t.attachMu.RLock()
|
||
|
|
in := t.in
|
||
|
|
out := t.out
|
||
|
|
errOut := t.errOut
|
||
|
|
t.attachMu.RUnlock()
|
||
|
|
if out == nil {
|
||
|
|
out = io.Discard
|
||
|
|
}
|
||
|
|
if errOut == nil {
|
||
|
|
errOut = out
|
||
|
|
}
|
||
|
|
|
||
|
|
inputReader, cancelInput, inputCancelable, err := prepareTerminalInputReader(in)
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
defer cancelInput()
|
||
|
|
|
||
|
|
var copyWG sync.WaitGroup
|
||
|
|
doneCopy := make(chan struct{})
|
||
|
|
copyWG.Add(2)
|
||
|
|
go func() {
|
||
|
|
defer copyWG.Done()
|
||
|
|
if t.stdout != nil {
|
||
|
|
_, _ = io.Copy(out, t.stdout)
|
||
|
|
}
|
||
|
|
}()
|
||
|
|
go func() {
|
||
|
|
defer copyWG.Done()
|
||
|
|
if t.stderr != nil {
|
||
|
|
_, _ = io.Copy(errOut, t.stderr)
|
||
|
|
}
|
||
|
|
}()
|
||
|
|
go func() {
|
||
|
|
copyWG.Wait()
|
||
|
|
close(doneCopy)
|
||
|
|
}()
|
||
|
|
|
||
|
|
var doneInput chan struct{}
|
||
|
|
if inputReader != nil && t.stdin != nil {
|
||
|
|
doneInput = make(chan struct{})
|
||
|
|
go func() {
|
||
|
|
defer close(doneInput)
|
||
|
|
_, _ = io.Copy(t.stdin, inputReader)
|
||
|
|
_ = t.stdin.Close()
|
||
|
|
}()
|
||
|
|
}
|
||
|
|
|
||
|
|
waitInputPump := func() {
|
||
|
|
if doneInput == nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
select {
|
||
|
|
case <-doneInput:
|
||
|
|
return
|
||
|
|
default:
|
||
|
|
}
|
||
|
|
|
||
|
|
if inputCancelable {
|
||
|
|
cancelInput()
|
||
|
|
<-doneInput
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
type waitResult struct {
|
||
|
|
info TerminalExitInfo
|
||
|
|
err error
|
||
|
|
}
|
||
|
|
|
||
|
|
waitCh := make(chan waitResult, 1)
|
||
|
|
go func() {
|
||
|
|
info, err := t.WaitResult()
|
||
|
|
waitCh <- waitResult{info: info, err: err}
|
||
|
|
}()
|
||
|
|
|
||
|
|
select {
|
||
|
|
case result := <-waitCh:
|
||
|
|
waitInputPump()
|
||
|
|
<-doneCopy
|
||
|
|
if result.err != nil {
|
||
|
|
return result.err
|
||
|
|
}
|
||
|
|
return result.info.CommandError()
|
||
|
|
case <-ctx.Done():
|
||
|
|
t.markCloseReason(terminalCloseReasonFromErr(ctx.Err()), ctx.Err())
|
||
|
|
cancelInput()
|
||
|
|
_ = t.Close()
|
||
|
|
<-waitCh
|
||
|
|
waitInputPump()
|
||
|
|
<-doneCopy
|
||
|
|
return ctx.Err()
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *TerminalSession) Wait() error {
|
||
|
|
info, err := t.WaitResult()
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
return info.CommandError()
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *TerminalSession) WaitResult() (TerminalExitInfo, error) {
|
||
|
|
waitErr := t.waitRaw()
|
||
|
|
info, closeErr := t.snapshotExitState()
|
||
|
|
if closeErr != nil {
|
||
|
|
return info, closeErr
|
||
|
|
}
|
||
|
|
if waitErr == nil {
|
||
|
|
return info, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
var exitErr *ssh.ExitError
|
||
|
|
if errors.As(waitErr, &exitErr) {
|
||
|
|
return info, nil
|
||
|
|
}
|
||
|
|
if normalizeAlreadyClosedError(waitErr) == nil || info.Reason == TerminalCloseReasonClosed {
|
||
|
|
return info, nil
|
||
|
|
}
|
||
|
|
return info, waitErr
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *TerminalSession) ExitInfo() TerminalExitInfo {
|
||
|
|
if t == nil {
|
||
|
|
return TerminalExitInfo{}
|
||
|
|
}
|
||
|
|
|
||
|
|
t.stateMu.RLock()
|
||
|
|
defer t.stateMu.RUnlock()
|
||
|
|
return t.exitInfo
|
||
|
|
}
|
||
|
|
|
||
|
|
func (info TerminalExitInfo) Success() bool {
|
||
|
|
return info.Reason == TerminalCloseReasonExit && info.ExitCode == 0 && info.ExitSignal == ""
|
||
|
|
}
|
||
|
|
|
||
|
|
func (info TerminalExitInfo) CommandError() error {
|
||
|
|
if info.Reason != TerminalCloseReasonExit && info.Reason != TerminalCloseReasonSignal {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
if info.ExitCode == 0 && info.ExitSignal == "" {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
return &ExecExitError{
|
||
|
|
Status: info.ExitCode,
|
||
|
|
Signal: info.ExitSignal,
|
||
|
|
Message: info.ExitMessage,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *TerminalSession) Resize(columns int, rows int) error {
|
||
|
|
if t == nil || t.Session == nil {
|
||
|
|
return errors.New("terminal session is not initialized")
|
||
|
|
}
|
||
|
|
if columns <= 0 || rows <= 0 {
|
||
|
|
return errors.New("columns and rows must be > 0")
|
||
|
|
}
|
||
|
|
return t.Session.WindowChange(rows, columns)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *TerminalSession) Close() error {
|
||
|
|
if t == nil {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
var closeErr error
|
||
|
|
t.closeOnce.Do(func() {
|
||
|
|
if t.stdin != nil {
|
||
|
|
_ = t.stdin.Close()
|
||
|
|
}
|
||
|
|
if t.Session != nil {
|
||
|
|
closeErr = normalizeAlreadyClosedError(t.Session.Close())
|
||
|
|
}
|
||
|
|
})
|
||
|
|
if closeErr != nil {
|
||
|
|
t.markCloseReason(TerminalCloseReasonTransportError, closeErr)
|
||
|
|
}
|
||
|
|
return closeErr
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *TerminalSession) waitRaw() error {
|
||
|
|
if t == nil || t.Session == nil {
|
||
|
|
return errors.New("terminal session is not initialized")
|
||
|
|
}
|
||
|
|
|
||
|
|
t.waitOnce.Do(func() {
|
||
|
|
go func() {
|
||
|
|
waitErr := t.Session.Wait()
|
||
|
|
t.setWaitResult(waitErr)
|
||
|
|
close(t.waitDone)
|
||
|
|
}()
|
||
|
|
})
|
||
|
|
|
||
|
|
<-t.waitDone
|
||
|
|
|
||
|
|
t.stateMu.RLock()
|
||
|
|
defer t.stateMu.RUnlock()
|
||
|
|
return t.waitErr
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *TerminalSession) setWaitResult(waitErr error) {
|
||
|
|
if t == nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
t.stateMu.Lock()
|
||
|
|
defer t.stateMu.Unlock()
|
||
|
|
t.waitErr = waitErr
|
||
|
|
t.exitInfo = buildTerminalExitInfo(waitErr, t.closeReason)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *TerminalSession) markCloseReason(reason TerminalCloseReason, err error) {
|
||
|
|
if t == nil || reason == TerminalCloseReasonUnknown {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
t.stateMu.Lock()
|
||
|
|
defer t.stateMu.Unlock()
|
||
|
|
|
||
|
|
if terminalCloseReasonPriority(reason) >= terminalCloseReasonPriority(t.closeReason) {
|
||
|
|
t.closeReason = reason
|
||
|
|
}
|
||
|
|
if err != nil && t.closeErr == nil {
|
||
|
|
t.closeErr = err
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *TerminalSession) snapshotExitState() (TerminalExitInfo, error) {
|
||
|
|
if t == nil {
|
||
|
|
return TerminalExitInfo{}, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
t.stateMu.RLock()
|
||
|
|
defer t.stateMu.RUnlock()
|
||
|
|
return t.exitInfo, t.closeErr
|
||
|
|
}
|
||
|
|
|
||
|
|
func buildTerminalExitInfo(waitErr error, overrideReason TerminalCloseReason) TerminalExitInfo {
|
||
|
|
info := TerminalExitInfo{}
|
||
|
|
|
||
|
|
if waitErr == nil {
|
||
|
|
info.Reason = TerminalCloseReasonExit
|
||
|
|
} else {
|
||
|
|
var exitErr *ssh.ExitError
|
||
|
|
switch {
|
||
|
|
case errors.As(waitErr, &exitErr):
|
||
|
|
info.ExitCode = exitErr.ExitStatus()
|
||
|
|
info.ExitSignal = exitErr.Signal()
|
||
|
|
info.ExitMessage = exitErr.Msg()
|
||
|
|
if info.ExitSignal != "" {
|
||
|
|
info.Reason = TerminalCloseReasonSignal
|
||
|
|
} else {
|
||
|
|
info.Reason = TerminalCloseReasonExit
|
||
|
|
}
|
||
|
|
case normalizeAlreadyClosedError(waitErr) == nil:
|
||
|
|
info.Reason = TerminalCloseReasonClosed
|
||
|
|
case errors.Is(waitErr, context.Canceled):
|
||
|
|
info.Reason = TerminalCloseReasonContextCanceled
|
||
|
|
case errors.Is(waitErr, context.DeadlineExceeded):
|
||
|
|
info.Reason = TerminalCloseReasonDeadlineExceeded
|
||
|
|
default:
|
||
|
|
info.Reason = TerminalCloseReasonTransportError
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
if overrideReason != TerminalCloseReasonUnknown &&
|
||
|
|
terminalCloseReasonPriority(overrideReason) >= terminalCloseReasonPriority(info.Reason) {
|
||
|
|
info.Reason = overrideReason
|
||
|
|
}
|
||
|
|
|
||
|
|
return info
|
||
|
|
}
|
||
|
|
|
||
|
|
func terminalCloseReasonFromErr(err error) TerminalCloseReason {
|
||
|
|
switch {
|
||
|
|
case errors.Is(err, context.DeadlineExceeded):
|
||
|
|
return TerminalCloseReasonDeadlineExceeded
|
||
|
|
case errors.Is(err, context.Canceled):
|
||
|
|
return TerminalCloseReasonContextCanceled
|
||
|
|
case err != nil:
|
||
|
|
return TerminalCloseReasonTransportError
|
||
|
|
default:
|
||
|
|
return TerminalCloseReasonUnknown
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func terminalCloseReasonPriority(reason TerminalCloseReason) int {
|
||
|
|
switch reason {
|
||
|
|
case TerminalCloseReasonContextCanceled, TerminalCloseReasonDeadlineExceeded:
|
||
|
|
return 30
|
||
|
|
case TerminalCloseReasonTransportError:
|
||
|
|
return 20
|
||
|
|
case TerminalCloseReasonClosed:
|
||
|
|
return 10
|
||
|
|
case TerminalCloseReasonSignal, TerminalCloseReasonExit:
|
||
|
|
return 5
|
||
|
|
default:
|
||
|
|
return 0
|
||
|
|
}
|
||
|
|
}
|