226 lines
5.6 KiB
Go
226 lines
5.6 KiB
Go
|
|
package starssh
|
||
|
|
|
||
|
|
import (
|
||
|
|
"bufio"
|
||
|
|
"bytes"
|
||
|
|
"errors"
|
||
|
|
"fmt"
|
||
|
|
"io"
|
||
|
|
"os"
|
||
|
|
"reflect"
|
||
|
|
"sync"
|
||
|
|
"unsafe"
|
||
|
|
)
|
||
|
|
|
||
|
|
// TerminalInputSourceProvider is implemented by wrapper readers that can hand
// out the reader they wrap, giving callers access to a source that is easier
// to close than the wrapper itself.
//
// A wrapper that buffers input must return a source whose stream already
// includes any bytes the wrapper has prefetched.
type TerminalInputSourceProvider interface {
	// TerminalInputSource returns the wrapped reader. Returning nil, or the
	// wrapper itself, means no separate source is available.
	TerminalInputSource() io.Reader
}
|
||
|
|
|
||
|
|
// TerminalInputCanceler is implemented by wrapper readers that provide an
// explicit hook for cancelling terminal input.
//
// It is intended for line editors and custom buffered readers for which
// exposing a raw io.ReadCloser would not be safe.
type TerminalInputCanceler interface {
	// TerminalInputCancel requests cancellation of the wrapper's input.
	TerminalInputCancel() error
}
|
||
|
|
|
||
|
|
// TerminalInputAdapter wraps a reader so that it can serve as a cancelable
// terminal input source.
//
// Reader is the stream that TerminalSession consumes. Source, when set, is the
// closer-friendly underlying reader. Cancel, when set, replaces the default
// close-based cancellation performed by TerminalInputCancel.
type TerminalInputAdapter struct {
	Reader io.Reader
	Source io.Reader
	Cancel func() error
}

// Read forwards to a.Reader. It reports (0, io.EOF) when no Reader is set.
func (a TerminalInputAdapter) Read(p []byte) (int, error) {
	if r := a.Reader; r != nil {
		return r.Read(p)
	}
	return 0, io.EOF
}

// TerminalInputSource returns a.Source if it is set, and a.Reader otherwise.
func (a TerminalInputAdapter) TerminalInputSource() io.Reader {
	if a.Source == nil {
		return a.Reader
	}
	return a.Source
}

// TerminalInputCancel cancels the adapter's input.
//
// If a.Cancel is set, its result is returned. Otherwise the first of a.Source
// and a.Reader (in that order) that implements io.Closer is closed, and its
// Close error is returned. If neither is a Closer, nil is returned.
func (a TerminalInputAdapter) TerminalInputCancel() error {
	if a.Cancel != nil {
		return a.Cancel()
	}
	for _, candidate := range [...]io.Reader{a.Source, a.Reader} {
		// A nil interface never satisfies the assertion, so ok alone suffices.
		if c, ok := candidate.(io.Closer); ok {
			return c.Close()
		}
	}
	return nil
}
|
||
|
|
|
||
|
|
func prepareTerminalInputReader(in io.Reader) (io.Reader, func(), bool, error) {
|
||
|
|
if in == nil {
|
||
|
|
return nil, func() {}, false, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
var cancelOnce sync.Once
|
||
|
|
wrapCancel := func(fn func()) func() {
|
||
|
|
return func() {
|
||
|
|
cancelOnce.Do(fn)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
if provider, ok := in.(TerminalInputSourceProvider); ok {
|
||
|
|
source := provider.TerminalInputSource()
|
||
|
|
if source == nil || sameReader(source, in) {
|
||
|
|
return prepareDirectTerminalInputReader(in, wrapCancel)
|
||
|
|
}
|
||
|
|
|
||
|
|
prepared, cancel, cancelable, err := prepareTerminalInputReader(source)
|
||
|
|
if err != nil {
|
||
|
|
return nil, nil, false, err
|
||
|
|
}
|
||
|
|
if canceler, ok := in.(TerminalInputCanceler); ok {
|
||
|
|
return prepared, wrapCancel(func() {
|
||
|
|
cancel()
|
||
|
|
_ = canceler.TerminalInputCancel()
|
||
|
|
}), true, nil
|
||
|
|
}
|
||
|
|
return prepared, cancel, cancelable, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
return prepareDirectTerminalInputReader(in, wrapCancel)
|
||
|
|
}
|
||
|
|
|
||
|
|
func prepareDirectTerminalInputReader(in io.Reader, wrapCancel func(func()) func()) (io.Reader, func(), bool, error) {
|
||
|
|
if in == nil {
|
||
|
|
return nil, func() {}, false, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
switch typed := in.(type) {
|
||
|
|
case *bufio.Reader:
|
||
|
|
return prepareBufferedTerminalInputReader(typed)
|
||
|
|
case *bufio.ReadWriter:
|
||
|
|
if typed.Reader == nil {
|
||
|
|
return in, func() {}, false, nil
|
||
|
|
}
|
||
|
|
return prepareBufferedTerminalInputReader(typed.Reader)
|
||
|
|
}
|
||
|
|
|
||
|
|
if canceler, ok := in.(TerminalInputCanceler); ok {
|
||
|
|
return in, wrapCancel(func() {
|
||
|
|
_ = canceler.TerminalInputCancel()
|
||
|
|
}), true, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
if file, ok := in.(*os.File); ok {
|
||
|
|
dup, err := duplicateTerminalInputFile(file)
|
||
|
|
if err != nil {
|
||
|
|
return nil, nil, false, fmt.Errorf("duplicate terminal input: %w", err)
|
||
|
|
}
|
||
|
|
return dup, wrapCancel(func() {
|
||
|
|
_ = dup.Close()
|
||
|
|
}), true, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
if closer, ok := in.(io.ReadCloser); ok {
|
||
|
|
return closer, wrapCancel(func() {
|
||
|
|
_ = closer.Close()
|
||
|
|
}), true, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
return in, func() {}, false, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func prepareBufferedTerminalInputReader(reader *bufio.Reader) (io.Reader, func(), bool, error) {
|
||
|
|
if reader == nil {
|
||
|
|
return nil, func() {}, false, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
bufferedPrefix, err := snapshotBufferedPrefix(reader)
|
||
|
|
if err != nil {
|
||
|
|
return nil, nil, false, err
|
||
|
|
}
|
||
|
|
|
||
|
|
underlying := unwrapBufioReader(reader)
|
||
|
|
if underlying == nil {
|
||
|
|
if len(bufferedPrefix) == 0 {
|
||
|
|
return reader, func() {}, false, nil
|
||
|
|
}
|
||
|
|
return io.MultiReader(bytes.NewReader(bufferedPrefix), reader), func() {}, false, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
prepared, cancel, cancelable, err := prepareTerminalInputReader(underlying)
|
||
|
|
if err != nil {
|
||
|
|
return nil, nil, false, err
|
||
|
|
}
|
||
|
|
if len(bufferedPrefix) == 0 {
|
||
|
|
return prepared, cancel, cancelable, nil
|
||
|
|
}
|
||
|
|
if prepared == nil {
|
||
|
|
return bytes.NewReader(bufferedPrefix), cancel, cancelable, nil
|
||
|
|
}
|
||
|
|
return io.MultiReader(bytes.NewReader(bufferedPrefix), prepared), cancel, cancelable, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// snapshotBufferedPrefix removes the bytes currently held in reader's buffer and
// returns them as an independent copy.
//
// It asks bufio for exactly reader.Buffered() bytes via Peek, copies them, then
// Discards the same count. It returns (nil, nil) when reader is nil or its
// buffer is empty. A Peek error other than io.EOF, or any Discard error, is
// wrapped and returned.
func snapshotBufferedPrefix(reader *bufio.Reader) ([]byte, error) {
	pending := 0
	if reader != nil {
		pending = reader.Buffered()
	}
	if pending == 0 {
		return nil, nil
	}

	view, peekErr := reader.Peek(pending)
	if peekErr != nil && !errors.Is(peekErr, io.EOF) {
		return nil, fmt.Errorf("peek terminal input buffer: %w", peekErr)
	}

	// Peek returns a window into bufio's internal buffer that later reads may
	// overwrite, so copy it out before discarding.
	snapshot := append([]byte(nil), view...)
	if _, discardErr := reader.Discard(len(snapshot)); discardErr != nil {
		return nil, fmt.Errorf("discard terminal input buffer: %w", discardErr)
	}
	return snapshot, nil
}
|
||
|
|
|
||
|
|
func unwrapBufioReader(reader *bufio.Reader) io.Reader {
|
||
|
|
if reader == nil {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
value := reflect.ValueOf(reader)
|
||
|
|
if value.Kind() != reflect.Pointer || value.IsNil() {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
field := value.Elem().FieldByName("rd")
|
||
|
|
if !field.IsValid() {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
underlyingValue := reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem()
|
||
|
|
underlying, ok := underlyingValue.Interface().(io.Reader)
|
||
|
|
if !ok || underlying == nil || sameReader(underlying, reader) {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
return underlying
|
||
|
|
}
|
||
|
|
|
||
|
|
// sameReader reports whether left and right are the very same pointer-backed
// reader, meaning they have an identical dynamic type and an identical address.
//
// Nil interfaces and non-pointer readers never compare as the same.
//
// The dynamic type must match as well as the address. Distinct readers can
// share an address, for example a struct pointer and a pointer to the struct's
// first field. Comparing raw addresses alone (Value.Pointer) would treat those
// as one reader, so a provider exposing its first field as its source would be
// mistaken for returning itself.
func sameReader(left io.Reader, right io.Reader) bool {
	if left == nil || right == nil {
		return false
	}
	if reflect.ValueOf(left).Kind() != reflect.Pointer || reflect.ValueOf(right).Kind() != reflect.Pointer {
		return false
	}
	// Both dynamic values are pointers, which are comparable, so interface
	// equality cannot panic. It compares dynamic type and address together.
	return left == right
}
|