Files
starssh/terminal_input.go
T

226 lines
5.6 KiB
Go
Raw Permalink Normal View History

package starssh
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"os"
"reflect"
"sync"
"unsafe"
)
// TerminalInputSourceProvider lets wrapper readers expose a closer-friendly source reader.
// Implementations that buffer data should return a source that already includes any prefetched bytes.
//
// When a reader given to prepareTerminalInputReader implements this interface,
// TerminalInputSource is consulted. A nil result, or a result that sameReader
// reports as the wrapper itself, means the wrapper is prepared directly.
// Any other result is prepared in the wrapper's place. The returned reader is
// then built from that source, so reads no longer go through the wrapper's
// own Read method.
type TerminalInputSourceProvider interface {
TerminalInputSource() io.Reader
}
// TerminalInputCanceler lets wrapper readers expose an explicit cancellation hook.
// It is useful for line editors or custom buffered readers that cannot safely expose a raw io.ReadCloser.
//
// prepareTerminalInputReader wires TerminalInputCancel into the cancel
// function it returns. That function runs the hook at most once and discards
// its error. When the reader also implements TerminalInputSourceProvider with a
// distinct source, the source's own cancel runs first and this hook runs after it.
type TerminalInputCanceler interface {
TerminalInputCancel() error
}
// TerminalInputAdapter turns an arbitrary reader into cancelable terminal input.
//
// TerminalSession reads through Reader. Source, when set, names the
// closer-friendly reader underneath it. Cancel, when set, replaces the default
// cancellation. The zero value reads as an exhausted stream (io.EOF) and
// cancels as a no-op.
type TerminalInputAdapter struct {
	// Reader serves every Read call; a nil Reader yields io.EOF.
	Reader io.Reader
	// Source is reported by TerminalInputSource; nil falls back to Reader.
	Source io.Reader
	// Cancel, if non-nil, is the cancellation hook used instead of closing
	// Source or Reader.
	Cancel func() error
}

// Read forwards to a.Reader, or reports io.EOF when no reader is configured.
func (a TerminalInputAdapter) Read(p []byte) (int, error) {
	if r := a.Reader; r != nil {
		return r.Read(p)
	}
	return 0, io.EOF
}

// TerminalInputSource reports Source, falling back to Reader when Source is nil.
func (a TerminalInputAdapter) TerminalInputSource() io.Reader {
	if a.Source == nil {
		return a.Reader
	}
	return a.Source
}

// TerminalInputCancel runs a.Cancel when it is set. Otherwise it closes the
// first of Source and Reader (in that order) that implements io.Closer, and
// returns that Close error. With nothing to call it returns nil.
func (a TerminalInputAdapter) TerminalInputCancel() error {
	if a.Cancel != nil {
		return a.Cancel()
	}
	for _, candidate := range [...]io.Reader{a.Source, a.Reader} {
		if c, ok := candidate.(io.Closer); ok {
			return c.Close()
		}
	}
	return nil
}
// prepareTerminalInputReader turns in into a reader a terminal session can
// consume, together with a hook that interrupts that input.
//
// It returns:
//   - the reader to read from,
//   - a cancel function (non-nil on success, nil on error),
//   - whether that cancel function actually interrupts input,
//   - any preparation error.
//
// Readers implementing TerminalInputSourceProvider are unwrapped recursively
// to their source. If such a wrapper also implements TerminalInputCanceler,
// its hook runs after the source's own cancel, and the combined hook fires
// at most once.
//
// NOTE(review): there is no cycle guard. A provider chain that loops back to
// an earlier reader, other than one sameReader recognizes, would recurse
// without bound.
func prepareTerminalInputReader(in io.Reader) (io.Reader, func(), bool, error) {
	if in == nil {
		return nil, func() {}, false, nil
	}

	var once sync.Once
	onceOnly := func(fn func()) func() {
		return func() { once.Do(fn) }
	}

	provider, isProvider := in.(TerminalInputSourceProvider)
	if !isProvider {
		return prepareDirectTerminalInputReader(in, onceOnly)
	}
	source := provider.TerminalInputSource()
	if source == nil || sameReader(source, in) {
		// The wrapper offers nothing better than itself.
		return prepareDirectTerminalInputReader(in, onceOnly)
	}

	reader, sourceCancel, cancelable, err := prepareTerminalInputReader(source)
	if err != nil {
		return nil, nil, false, err
	}
	canceler, isCanceler := in.(TerminalInputCanceler)
	if !isCanceler {
		return reader, sourceCancel, cancelable, nil
	}
	return reader, onceOnly(func() {
		sourceCancel()
		_ = canceler.TerminalInputCancel()
	}), true, nil
}
// prepareDirectTerminalInputReader prepares in without consulting
// TerminalInputSourceProvider. The first matching rule wins:
//   - *bufio.Reader, or a *bufio.ReadWriter with a non-nil Reader: buffered
//     bytes are preserved and the underlying source is prepared.
//   - TerminalInputCanceler: in is returned, and cancel invokes its hook.
//   - *os.File: a duplicate from duplicateTerminalInputFile is returned, and
//     cancel closes only the duplicate.
//   - io.ReadCloser: in is returned, and cancel closes it.
//   - Anything else, including a *bufio.ReadWriter with a nil Reader: in is
//     returned as-is with a no-op, non-interrupting cancel.
//
// wrapCancel decorates every real cancel hook; prepareTerminalInputReader
// uses it to make the hook run at most once. Errors from cancel hooks are
// discarded.
func prepareDirectTerminalInputReader(in io.Reader, wrapCancel func(func()) func()) (io.Reader, func(), bool, error) {
	noop := func() {}
	if in == nil {
		return nil, noop, false, nil
	}
	switch r := in.(type) {
	case *bufio.Reader:
		return prepareBufferedTerminalInputReader(r)
	case *bufio.ReadWriter:
		if r.Reader != nil {
			return prepareBufferedTerminalInputReader(r.Reader)
		}
		return in, noop, false, nil
	case TerminalInputCanceler:
		return in, wrapCancel(func() { _ = r.TerminalInputCancel() }), true, nil
	case *os.File:
		dup, err := duplicateTerminalInputFile(r)
		if err != nil {
			return nil, nil, false, fmt.Errorf("duplicate terminal input: %w", err)
		}
		return dup, wrapCancel(func() { _ = dup.Close() }), true, nil
	case io.ReadCloser:
		return r, wrapCancel(func() { _ = r.Close() }), true, nil
	default:
		return in, noop, false, nil
	}
}
// prepareBufferedTerminalInputReader prepares a *bufio.Reader for terminal use.
//
// If the reader's underlying source can be recovered (see unwrapBufioReader),
// that source is prepared with prepareTerminalInputReader. Bytes already held
// in the bufio buffer are replayed in front of it, so no buffered input is
// lost. If the source cannot be recovered, reading continues through the
// bufio.Reader itself, which cannot be canceled.
//
// The source is prepared before the buffer is drained. If preparation fails,
// the caller's bufio.Reader is left untouched and still holds its buffered
// bytes. On success, the buffered bytes have been moved out of reader.
func prepareBufferedTerminalInputReader(reader *bufio.Reader) (io.Reader, func(), bool, error) {
	if reader == nil {
		return nil, func() {}, false, nil
	}

	underlying := unwrapBufioReader(reader)
	if underlying == nil {
		bufferedPrefix, err := snapshotBufferedPrefix(reader)
		if err != nil {
			return nil, nil, false, err
		}
		if len(bufferedPrefix) == 0 {
			return reader, func() {}, false, nil
		}
		return io.MultiReader(bytes.NewReader(bufferedPrefix), reader), func() {}, false, nil
	}

	// Prepare first: draining the buffer is destructive. Doing it before a
	// fallible step (e.g. duplicating a file descriptor) would drop the
	// caller's buffered input whenever that step failed.
	prepared, cancel, cancelable, err := prepareTerminalInputReader(underlying)
	if err != nil {
		return nil, nil, false, err
	}
	bufferedPrefix, err := snapshotBufferedPrefix(reader)
	if err != nil {
		// The caller gets no cancel function on error, so release whatever
		// preparation acquired (such as a duplicated descriptor) here.
		cancel()
		return nil, nil, false, err
	}

	if len(bufferedPrefix) == 0 {
		return prepared, cancel, cancelable, nil
	}
	if prepared == nil {
		return bytes.NewReader(bufferedPrefix), cancel, cancelable, nil
	}
	return io.MultiReader(bytes.NewReader(bufferedPrefix), prepared), cancel, cancelable, nil
}
func snapshotBufferedPrefix(reader *bufio.Reader) ([]byte, error) {
if reader == nil {
return nil, nil
}
buffered := reader.Buffered()
if buffered == 0 {
return nil, nil
}
chunk, err := reader.Peek(buffered)
if err != nil && !errors.Is(err, io.EOF) {
return nil, fmt.Errorf("peek terminal input buffer: %w", err)
}
prefix := append([]byte(nil), chunk...)
if _, err := reader.Discard(len(prefix)); err != nil {
return nil, fmt.Errorf("discard terminal input buffer: %w", err)
}
return prefix, nil
}
// unwrapBufioReader returns the io.Reader that reader was built around, or nil
// when that cannot be determined.
//
// bufio.Reader does not export its source, so this reads the unexported "rd"
// field through reflection. reflect refuses Interface() on values reached via
// unexported fields, so the field's memory is re-viewed with reflect.NewAt to
// obtain a usable Value. This relies on bufio's internal layout. If the field
// is missing, does not hold a non-nil io.Reader, or refers back to reader
// itself (per sameReader), nil is returned.
func unwrapBufioReader(reader *bufio.Reader) io.Reader {
	if reader == nil {
		return nil
	}
	rdField := reflect.ValueOf(reader).Elem().FieldByName("rd")
	if !rdField.IsValid() {
		return nil
	}
	fieldPtr := unsafe.Pointer(rdField.UnsafeAddr())
	raw := reflect.NewAt(rdField.Type(), fieldPtr).Elem().Interface()
	source, isReader := raw.(io.Reader)
	if !isReader || source == nil || sameReader(source, reader) {
		return nil
	}
	return source
}
// sameReader reports whether left and right denote the same reader.
//
// Both must be non-nil and have the same dynamic type.
//   - Pointer readers match when they hold the same address. The type check
//     matters: distinct zero-size allocations may share an address, and a
//     struct pointer shares its address with a pointer to its first field,
//     so address alone can equate unrelated readers.
//   - Non-pointer readers of a comparable type match when they are ==. A value
//     type whose TerminalInputSource returns itself is therefore recognized,
//     instead of being unwrapped again and again.
//   - Uncomparable values never match.
func sameReader(left io.Reader, right io.Reader) bool {
	if left == nil || right == nil {
		return false
	}
	leftValue := reflect.ValueOf(left)
	rightValue := reflect.ValueOf(right)
	if leftValue.Type() != rightValue.Type() {
		return false
	}
	if leftValue.Kind() == reflect.Pointer {
		return leftValue.Pointer() == rightValue.Pointer()
	}
	return readersEqualByValue(left, right)
}

// readersEqualByValue compares two same-typed readers with ==. It reports
// false instead of panicking when the type, or a value nested in an interface
// field, is not comparable.
func readersEqualByValue(left, right io.Reader) (equal bool) {
	if !reflect.TypeOf(left).Comparable() {
		return false
	}
	// A struct with interface fields is Comparable(), yet == still panics if
	// those fields hold e.g. slices at run time.
	defer func() {
		if recover() != nil {
			equal = false
		}
	}()
	return left == right
}