package starssh import ( "bufio" "bytes" "errors" "fmt" "io" "os" "reflect" "sync" "unsafe" ) // TerminalInputSourceProvider lets wrapper readers expose a closer-friendly source reader. // Implementations that buffer data should return a source that already includes any prefetched bytes. type TerminalInputSourceProvider interface { TerminalInputSource() io.Reader } // TerminalInputCanceler lets wrapper readers expose an explicit cancellation hook. // It is useful for line editors or custom buffered readers that cannot safely expose a raw io.ReadCloser. type TerminalInputCanceler interface { TerminalInputCancel() error } // TerminalInputAdapter adapts wrapper readers into a cancelable terminal input source. // Reader is what TerminalSession consumes, Source is the closer-friendly underlying reader when available. type TerminalInputAdapter struct { Reader io.Reader Source io.Reader Cancel func() error } func (a TerminalInputAdapter) Read(p []byte) (int, error) { if a.Reader == nil { return 0, io.EOF } return a.Reader.Read(p) } func (a TerminalInputAdapter) TerminalInputSource() io.Reader { if a.Source != nil { return a.Source } return a.Reader } func (a TerminalInputAdapter) TerminalInputCancel() error { if a.Cancel != nil { return a.Cancel() } if closer, ok := a.Source.(io.Closer); ok && closer != nil { return closer.Close() } if closer, ok := a.Reader.(io.Closer); ok && closer != nil { return closer.Close() } return nil } func prepareTerminalInputReader(in io.Reader) (io.Reader, func(), bool, error) { if in == nil { return nil, func() {}, false, nil } var cancelOnce sync.Once wrapCancel := func(fn func()) func() { return func() { cancelOnce.Do(fn) } } if provider, ok := in.(TerminalInputSourceProvider); ok { source := provider.TerminalInputSource() if source == nil || sameReader(source, in) { return prepareDirectTerminalInputReader(in, wrapCancel) } prepared, cancel, cancelable, err := prepareTerminalInputReader(source) if err != nil { return nil, nil, false, err } if canceler, ok := in.(TerminalInputCanceler); ok { return prepared, wrapCancel(func() { cancel() _ = canceler.TerminalInputCancel() }), true, nil } return prepared, cancel, cancelable, nil } return prepareDirectTerminalInputReader(in, wrapCancel) } func prepareDirectTerminalInputReader(in io.Reader, wrapCancel func(func()) func()) (io.Reader, func(), bool, error) { if in == nil { return nil, func() {}, false, nil } switch typed := in.(type) { case *bufio.Reader: return prepareBufferedTerminalInputReader(typed) case *bufio.ReadWriter: if typed.Reader == nil { return in, func() {}, false, nil } return prepareBufferedTerminalInputReader(typed.Reader) } if canceler, ok := in.(TerminalInputCanceler); ok { return in, wrapCancel(func() { _ = canceler.TerminalInputCancel() }), true, nil } if file, ok := in.(*os.File); ok { dup, err := duplicateTerminalInputFile(file) if err != nil { return nil, nil, false, fmt.Errorf("duplicate terminal input: %w", err) } return dup, wrapCancel(func() { _ = dup.Close() }), true, nil } if closer, ok := in.(io.ReadCloser); ok { return closer, wrapCancel(func() { _ = closer.Close() }), true, nil } return in, func() {}, false, nil } func prepareBufferedTerminalInputReader(reader *bufio.Reader) (io.Reader, func(), bool, error) { if reader == nil { return nil, func() {}, false, nil } bufferedPrefix, err := snapshotBufferedPrefix(reader) if err != nil { return nil, nil, false, err } underlying := unwrapBufioReader(reader) if underlying == nil { if len(bufferedPrefix) == 0 { return reader, func() {}, false, nil } return io.MultiReader(bytes.NewReader(bufferedPrefix), reader), func() {}, false, nil } prepared, cancel, cancelable, err := prepareTerminalInputReader(underlying) if err != nil { return nil, nil, false, err } if len(bufferedPrefix) == 0 { return prepared, cancel, cancelable, nil } if prepared == nil { return bytes.NewReader(bufferedPrefix), cancel, cancelable, nil } return io.MultiReader(bytes.NewReader(bufferedPrefix), prepared), cancel, cancelable, nil } func snapshotBufferedPrefix(reader *bufio.Reader) ([]byte, error) { if reader == nil { return nil, nil } buffered := reader.Buffered() if buffered == 0 { return nil, nil } chunk, err := reader.Peek(buffered) if err != nil && !errors.Is(err, io.EOF) { return nil, fmt.Errorf("peek terminal input buffer: %w", err) } prefix := append([]byte(nil), chunk...) if _, err := reader.Discard(len(prefix)); err != nil { return nil, fmt.Errorf("discard terminal input buffer: %w", err) } return prefix, nil } func unwrapBufioReader(reader *bufio.Reader) io.Reader { if reader == nil { return nil } value := reflect.ValueOf(reader) if value.Kind() != reflect.Pointer || value.IsNil() { return nil } field := value.Elem().FieldByName("rd") if !field.IsValid() { return nil } underlyingValue := reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem() underlying, ok := underlyingValue.Interface().(io.Reader) if !ok || underlying == nil || sameReader(underlying, reader) { return nil } return underlying } func sameReader(left io.Reader, right io.Reader) bool { if left == nil || right == nil { return false } leftValue := reflect.ValueOf(left) rightValue := reflect.ValueOf(right) if !leftValue.IsValid() || !rightValue.IsValid() { return false } if leftValue.Kind() != reflect.Pointer || rightValue.Kind() != reflect.Pointer { return false } return leftValue.Pointer() == rightValue.Pointer() }