package notify

import (
	"context"
	"io"
	"sync"
)

// StreamCopyOptions configures CopyToStream and CopyFromStream.
type StreamCopyOptions struct {
	// BufferSize is the copy buffer size in bytes; values <= 0 select
	// defaultFileChunkSize.
	BufferSize int
	// CloseWrite makes CopyToStream half-close the stream (Stream.CloseWrite)
	// after the source has been fully copied. It takes precedence over
	// CloseStream.
	CloseWrite bool
	// CloseStream makes CopyToStream fully close the stream (Stream.Close)
	// after the source has been fully copied, when CloseWrite is not set.
	CloseStream bool
	// CloseWriter makes CopyFromStream close the destination after the stream
	// has been fully copied, if the destination implements io.Closer.
	CloseWriter bool
}

// StreamBridgeOptions configures BridgeStream.
type StreamBridgeOptions struct {
	// BufferSize is the per-direction copy buffer size in bytes; values <= 0
	// select defaultFileChunkSize.
	BufferSize int
	// ClosePeerOnEOF makes BridgeStream close the peer once the stream reaches
	// EOF. It only applies when the peer has no CloseWrite method; a peer that
	// can be half-closed is always half-closed instead.
	ClosePeerOnEOF bool
	// ResetOnCopyError makes BridgeStream tear the stream down with
	// Stream.Reset instead of Stream.Close when the bridge is aborted.
	ResetOnCopyError bool
}

// StreamOpenCopyOptions configures the Open*StreamFromReader helpers.
type StreamOpenCopyOptions struct {
	// Open is passed unchanged to the underlying stream-open call.
	Open StreamOpenOptions
	// Copy controls the upload of the reader into the new stream. When neither
	// Copy.CloseWrite nor Copy.CloseStream is set, Copy.CloseWrite is enabled.
	Copy StreamCopyOptions
}

// CopyToStream copies src into stream until src is exhausted, a read or write
// fails, or ctx is done.
//
// After src has been fully copied, the stream is half-closed when
// opt.CloseWrite is set, otherwise fully closed when opt.CloseStream is set;
// an error from that close is returned. On a copy error the stream is left
// untouched. A nil stream or src yields io.ErrClosedPipe.
//
// ctx is checked only between reads: a Read already blocked on src is not
// interrupted by cancellation.
func CopyToStream(ctx context.Context, stream Stream, src io.Reader, opt StreamCopyOptions) (int64, error) {
	if stream == nil || src == nil {
		return 0, io.ErrClosedPipe
	}
	written, err := copyWithStreamContext(ctx, stream, src, opt.BufferSize)
	if err != nil {
		return written, err
	}
	switch {
	case opt.CloseWrite:
		err = stream.CloseWrite()
	case opt.CloseStream:
		err = stream.Close()
	}
	return written, err
}

// CopyFromStream copies stream into dst until the stream is exhausted, a read
// or write fails, or ctx is done.
//
// After the stream has been fully copied and opt.CloseWriter is set, dst is
// closed if it implements io.Closer and that close error is returned. A nil
// stream or dst yields io.ErrClosedPipe. Like CopyToStream, ctx is checked
// only between reads.
func CopyFromStream(ctx context.Context, dst io.Writer, stream Stream, opt StreamCopyOptions) (int64, error) {
	if stream == nil || dst == nil {
		return 0, io.ErrClosedPipe
	}
	written, err := copyWithStreamContext(ctx, dst, stream, opt.BufferSize)
	if err != nil || !opt.CloseWriter {
		return written, err
	}
	if closer, ok := dst.(io.Closer); ok {
		return written, closer.Close()
	}
	return written, nil
}

// copyWithStreamContext copies src to dst through a bufSize-byte buffer
// (defaultFileChunkSize when bufSize <= 0), checking ctx before every read.
// A clean end of src is reported as a nil error.
func copyWithStreamContext(ctx context.Context, dst io.Writer, src io.Reader, bufSize int) (int64, error) {
	if bufSize <= 0 {
		bufSize = defaultFileChunkSize
	}
	buf := make([]byte, bufSize)
	written, err := io.CopyBuffer(dst, newContextReader(ctx, src), buf)
	return written, normalizeStreamCopyError(err)
}

// contextReader wraps a reader so that each Read first fails with ctx.Err()
// once ctx is done.
type contextReader struct {
	ctx context.Context
	src io.Reader
}

// newContextReader wraps src in a contextReader. With a nil ctx or src it
// returns src unchanged (no cancellation checks).
func newContextReader(ctx context.Context, src io.Reader) io.Reader {
	if ctx == nil || src == nil {
		return src
	}
	return &contextReader{ctx: ctx, src: src}
}

// Read returns ctx.Err() if the context is already done; otherwise it
// delegates to the wrapped reader. A nil receiver or source reads as io.EOF.
func (r *contextReader) Read(p []byte) (int, error) {
	if r == nil || r.src == nil {
		return 0, io.EOF
	}
	select {
	case <-r.ctx.Done():
		return 0, r.ctx.Err()
	default:
	}
	return r.src.Read(p)
}

// normalizeStreamCopyError maps io.EOF to nil and passes any other error
// through unchanged.
func normalizeStreamCopyError(err error) error {
	if err == io.EOF {
		return nil
	}
	return err
}

// BridgeStream copies data in both directions between stream and peer until
// both directions finish or the bridge is aborted.
//
// Direction peer->stream half-closes the stream (CloseWrite) after peer EOF.
// Direction stream->peer, after stream EOF, half-closes peer if it has a
// CloseWrite method, otherwise closes it when opt.ClosePeerOnEOF is set.
//
// If either direction fails, or ctx is done, the bridge is aborted once: the
// internal context is canceled, peer is closed, and stream is reset (when
// opt.ResetOnCopyError is set) or closed. The returned error is then ctx.Err()
// if ctx is done, otherwise the error that triggered the abort. When both
// directions finish cleanly, BridgeStream returns nil without fully closing
// either side. A nil stream or peer yields io.ErrClosedPipe; a nil ctx is
// treated as context.Background().
func BridgeStream(ctx context.Context, stream Stream, peer io.ReadWriteCloser, opt StreamBridgeOptions) error {
	if stream == nil || peer == nil {
		return io.ErrClosedPipe
	}
	if ctx == nil {
		ctx = context.Background()
	}
	bridgeCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	var (
		abortOnce  sync.Once
		primaryErr error
	)
	// abortBridge tears the bridge down exactly once and records its cause.
	// bridgeCtx is canceled only here, after primaryErr is recorded, so a
	// direction that merely observes that cancellation can never win the Once
	// and replace the real cause with context.Canceled.
	abortBridge := func(err error) {
		abortOnce.Do(func() {
			primaryErr = err
			cancel()
			// Best effort: closing both ends is what unblocks a Read/Write that
			// is already in flight (the context is only checked between reads).
			_ = peer.Close()
			if err != nil && opt.ResetOnCopyError {
				_ = stream.Reset(err)
				return
			}
			_ = stream.Close()
		})
	}

	// Watch the caller's ctx for the lifetime of the bridge.
	watchDone := make(chan struct{})
	watchStopped := make(chan struct{})
	go func() {
		defer close(watchStopped)
		select {
		case <-ctx.Done():
			abortBridge(ctx.Err())
		case <-watchDone:
		}
	}()

	var wg sync.WaitGroup
	runCopy := func(fn func() error) {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if err := fn(); err != nil {
				abortBridge(err)
			}
		}()
	}
	runCopy(func() error {
		_, err := CopyToStream(bridgeCtx, stream, peer, StreamCopyOptions{
			BufferSize: opt.BufferSize,
			CloseWrite: true,
		})
		return err
	})
	runCopy(func() error {
		_, err := copyFromStreamToBridgePeer(bridgeCtx, peer, stream, opt)
		return err
	})
	wg.Wait()

	// Disarm abortBridge: if neither direction failed, a ctx cancellation that
	// races with completion must not tear down or fail an already-finished
	// bridge. This also orders the read of primaryErr below after any write.
	abortOnce.Do(func() {})
	close(watchDone)
	<-watchStopped

	if primaryErr == nil {
		return nil
	}
	if err := ctx.Err(); err != nil {
		return err
	}
	return primaryErr
}

// streamBridgeCloseWriter is implemented by peers that support half-closing
// their write side.
type streamBridgeCloseWriter interface {
	CloseWrite() error
}

// copyFromStreamToBridgePeer copies stream into peer. After the stream has
// been fully copied it half-closes peer if peer implements CloseWrite,
// otherwise closes peer when opt.ClosePeerOnEOF is set.
//
// NOTE(review): fully closing peer here also affects the peer->stream
// direction, which is presumably still reading from peer; confirm that a
// resulting read error (and bridge abort) is the intended behaviour.
func copyFromStreamToBridgePeer(ctx context.Context, peer io.ReadWriteCloser, stream Stream, opt StreamBridgeOptions) (int64, error) {
	written, err := CopyFromStream(ctx, peer, stream, StreamCopyOptions{
		BufferSize: opt.BufferSize,
	})
	if err != nil {
		return written, err
	}
	if closeWriter, ok := peer.(streamBridgeCloseWriter); ok {
		return written, closeWriter.CloseWrite()
	}
	if opt.ClosePeerOnEOF {
		return written, peer.Close()
	}
	return written, nil
}

// OpenClientStreamFromReader opens a stream with c.OpenStream and uploads src
// into it (see openStreamFromReader for close and error semantics). A nil
// client yields errStreamClientNil.
func OpenClientStreamFromReader(ctx context.Context, c Client, src io.Reader, opt StreamOpenCopyOptions) (Stream, int64, error) {
	if c == nil {
		return nil, 0, errStreamClientNil
	}
	return openStreamFromReader(ctx, src, opt, c.OpenStream)
}

// OpenServerLogicalStreamFromReader opens a stream on logical with
// s.OpenStreamLogical and uploads src into it (see openStreamFromReader for
// close and error semantics). A nil server or logical connection yields
// errStreamServerNil or errStreamLogicalConnNil respectively.
func OpenServerLogicalStreamFromReader(ctx context.Context, s Server, logical *LogicalConn, src io.Reader, opt StreamOpenCopyOptions) (Stream, int64, error) {
	if s == nil {
		return nil, 0, errStreamServerNil
	}
	if logical == nil {
		return nil, 0, errStreamLogicalConnNil
	}
	return openStreamFromReader(ctx, src, opt, func(ctx context.Context, openOpt StreamOpenOptions) (Stream, error) {
		return s.OpenStreamLogical(ctx, logical, openOpt)
	})
}

// OpenServerTransportStreamFromReader opens a stream on transport with
// s.OpenStreamTransport and uploads src into it (see openStreamFromReader for
// close and error semantics). A nil server or transport yields
// errStreamServerNil or errStreamTransportNil respectively.
func OpenServerTransportStreamFromReader(ctx context.Context, s Server, transport *TransportConn, src io.Reader, opt StreamOpenCopyOptions) (Stream, int64, error) {
	if s == nil {
		return nil, 0, errStreamServerNil
	}
	if transport == nil {
		return nil, 0, errStreamTransportNil
	}
	return openStreamFromReader(ctx, src, opt, func(ctx context.Context, openOpt StreamOpenOptions) (Stream, error) {
		return s.OpenStreamTransport(ctx, transport, openOpt)
	})
}

// CopyStreamToWriter is CopyFromStream with the stream argument first.
func CopyStreamToWriter(ctx context.Context, stream Stream, dst io.Writer, opt StreamCopyOptions) (int64, error) {
	return CopyFromStream(ctx, dst, stream, opt)
}

// openStreamFromReader opens a stream with openFn(ctx, opt.Open) and copies
// src into it with CopyToStream. Unless opt.Copy selects CloseStream, the
// stream is half-closed after src has been fully copied.
//
// If opening fails, it returns (nil, 0, err). If copying fails, the stream is
// Reset with the copy error and is still returned, together with the number
// of bytes written and that error. A nil src or openFn yields
// io.ErrClosedPipe.
func openStreamFromReader(ctx context.Context, src io.Reader, opt StreamOpenCopyOptions, openFn func(context.Context, StreamOpenOptions) (Stream, error)) (Stream, int64, error) {
	if src == nil {
		return nil, 0, io.ErrClosedPipe
	}
	if openFn == nil {
		return nil, 0, io.ErrClosedPipe
	}
	opt = normalizeStreamOpenCopyOptions(opt)
	stream, err := openFn(ctx, opt.Open)
	if err != nil {
		return nil, 0, err
	}
	written, err := CopyToStream(ctx, stream, src, opt.Copy)
	if err != nil {
		_ = stream.Reset(err) // best effort; the copy error is what gets reported
		return stream, written, err
	}
	return stream, written, nil
}

// normalizeStreamOpenCopyOptions enables Copy.CloseWrite when the caller chose
// neither CloseWrite nor CloseStream, so an uploaded stream always signals
// end-of-data to its reader.
func normalizeStreamOpenCopyOptions(opt StreamOpenCopyOptions) StreamOpenCopyOptions {
	if !opt.Copy.CloseWrite && !opt.Copy.CloseStream {
		opt.Copy.CloseWrite = true
	}
	return opt
}