notify/stream_helper.go

289 lines
6.4 KiB
Go
Raw Normal View History

package notify
import (
"context"
"errors"
"io"
"sync"
)
// StreamCopyOptions configures CopyToStream and CopyFromStream (and the
// helpers built on top of them).
type StreamCopyOptions struct {
// BufferSize is the size in bytes of the copy buffer; values <= 0 fall
// back to defaultFileChunkSize.
BufferSize int
// CloseWrite makes CopyToStream call stream.CloseWrite after a successful
// copy. It takes precedence over CloseStream.
CloseWrite bool
// CloseStream makes CopyToStream call stream.Close after a successful copy
// when CloseWrite is not set.
CloseStream bool
// CloseWriter makes CopyFromStream close dst after a successful copy, if
// dst implements io.Closer.
CloseWriter bool
}
// StreamBridgeOptions configures BridgeStream.
type StreamBridgeOptions struct {
// BufferSize is forwarded as the copy buffer size for both directions;
// values <= 0 select the copy helpers' default.
BufferSize int
// ClosePeerOnEOF closes the peer once the stream has been fully copied to
// it, but only when the peer has no CloseWrite method (a peer with
// CloseWrite gets CloseWrite called instead).
ClosePeerOnEOF bool
// ResetOnCopyError makes an aborted bridge Reset the stream with the abort
// error instead of closing it. Cancellation of the ctx passed to
// BridgeStream also aborts the bridge and is treated the same way.
ResetOnCopyError bool
}
// StreamOpenCopyOptions configures the Open*StreamFromReader helpers, which
// open a stream and then copy a reader into it.
type StreamOpenCopyOptions struct {
// Open is passed unchanged to the function that opens the stream.
Open StreamOpenOptions
// Copy controls the copy from the reader into the new stream. If neither
// Copy.CloseWrite nor Copy.CloseStream is set, CloseWrite is enabled.
Copy StreamCopyOptions
}
// CopyToStream writes everything read from src into stream and returns the
// number of bytes written.
//
// The copy buffer holds opt.BufferSize bytes (defaultFileChunkSize when
// opt.BufferSize <= 0). When ctx is non-nil, the copy stops with ctx.Err()
// before the next read once ctx is done; a read already blocked in src is not
// interrupted. After a successful copy, stream.CloseWrite is called if
// opt.CloseWrite is set, otherwise stream.Close is called if opt.CloseStream
// is set; a failure from that call is returned. On a copy error the stream is
// left untouched. EOF from src is not reported as an error. A nil stream or
// src yields io.ErrClosedPipe.
func CopyToStream(ctx context.Context, stream Stream, src io.Reader, opt StreamCopyOptions) (int64, error) {
	if stream == nil || src == nil {
		return 0, io.ErrClosedPipe
	}
	size := opt.BufferSize
	if size <= 0 {
		size = defaultFileChunkSize
	}
	n, copyErr := io.CopyBuffer(stream, newContextReader(ctx, src), make([]byte, size))
	if copyErr != nil && copyErr != io.EOF {
		// No close on failure: callers (e.g. openStreamFromReader) choose
		// how to tear the stream down.
		return n, copyErr
	}
	switch {
	case opt.CloseWrite:
		if err := stream.CloseWrite(); err != nil {
			return n, err
		}
	case opt.CloseStream:
		if err := stream.Close(); err != nil {
			return n, err
		}
	}
	return n, nil
}
// CopyFromStream drains stream into dst and returns the number of bytes
// written to dst.
//
// The copy buffer holds opt.BufferSize bytes (defaultFileChunkSize when
// opt.BufferSize <= 0). When ctx is non-nil, the copy stops with ctx.Err()
// before the next read once ctx is done; a read already blocked in stream is
// not interrupted. After a successful copy, dst is closed when
// opt.CloseWriter is set and dst implements io.Closer; a close failure is
// returned. EOF from stream is not reported as an error. A nil stream or dst
// yields io.ErrClosedPipe.
func CopyFromStream(ctx context.Context, dst io.Writer, stream Stream, opt StreamCopyOptions) (int64, error) {
	if stream == nil || dst == nil {
		return 0, io.ErrClosedPipe
	}
	size := opt.BufferSize
	if size <= 0 {
		size = defaultFileChunkSize
	}
	n, copyErr := io.CopyBuffer(dst, newContextReader(ctx, stream), make([]byte, size))
	if copyErr != nil && copyErr != io.EOF {
		return n, copyErr
	}
	if !opt.CloseWriter {
		return n, nil
	}
	if c, ok := dst.(io.Closer); ok {
		if err := c.Close(); err != nil {
			return n, err
		}
	}
	return n, nil
}
// contextReader wraps an io.Reader and refuses to start a new Read once its
// context is done. A Read that is already blocked in src is not interrupted.
type contextReader struct {
	ctx context.Context
	src io.Reader
}

// newContextReader wraps src so that reads fail with ctx.Err() after ctx is
// done. With a nil ctx or a nil src, src is returned as-is (unwrapped).
func newContextReader(ctx context.Context, src io.Reader) io.Reader {
	if ctx != nil && src != nil {
		return &contextReader{ctx: ctx, src: src}
	}
	return src
}

// Read checks the context before delegating to the wrapped reader. A nil
// receiver or a missing source behaves like an exhausted reader (io.EOF).
func (r *contextReader) Read(p []byte) (int, error) {
	if r == nil || r.src == nil {
		return 0, io.EOF
	}
	// Per the context.Context contract, Err is non-nil exactly when Done is
	// closed, so this matches a non-blocking select on Done.
	if err := r.ctx.Err(); err != nil {
		return 0, err
	}
	return r.src.Read(p)
}
// normalizeStreamCopyError maps a plain io.EOF to nil, because reaching the
// end of the source is a normal completion for a copy. Every other value,
// including nil and wrapped errors, is passed through unchanged.
func normalizeStreamCopyError(err error) error {
	switch err {
	case io.EOF:
		return nil
	default:
		return err
	}
}
// BridgeStream pumps data in both directions between stream and peer and
// blocks until both directions have finished.
//
// Data read from peer is copied into stream with CopyToStream, which calls
// stream.CloseWrite once peer reports EOF. Data read from stream is copied
// into peer with copyFromStreamToBridgePeer, which calls peer.CloseWrite (or
// peer.Close when opt.ClosePeerOnEOF is set and the peer has no CloseWrite)
// once stream reports EOF.
//
// The first failure, either a copy error or cancellation of ctx, aborts the
// bridge exactly once. The internal context is cancelled and peer is closed.
// Then stream is Reset with the failure when opt.ResetOnCopyError is set, or
// Closed otherwise. After an abort, BridgeStream returns ctx.Err() if ctx has
// been cancelled, and otherwise that first failure. When both directions
// complete cleanly it returns nil and performs no close beyond the ones
// described above.
//
// A nil ctx is treated as context.Background(); a nil stream or peer yields
// io.ErrClosedPipe.
func BridgeStream(ctx context.Context, stream Stream, peer io.ReadWriteCloser, opt StreamBridgeOptions) error {
	if stream == nil || peer == nil {
		return io.ErrClosedPipe
	}
	if ctx == nil {
		ctx = context.Background()
	}
	bridgeCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	var wg sync.WaitGroup
	errCh := make(chan error, 2)
	var abortOnce sync.Once
	var primaryErr error
	// abortBridge tears the bridge down at most once. primaryErr is recorded
	// before cancel() runs, and sync.Once makes concurrent callers wait and
	// then skip. A sibling copy that only fails because of this teardown
	// (context.Canceled, closed connection) therefore cannot replace the
	// real cause.
	abortBridge := func(err error) {
		abortOnce.Do(func() {
			primaryErr = err
			cancel()
			// Best effort: closing both ends unblocks any Read/Write still
			// in flight in the other direction.
			_ = peer.Close()
			if err != nil && opt.ResetOnCopyError {
				_ = stream.Reset(err)
				return
			}
			_ = stream.Close()
		})
	}

	// The watcher turns cancellation of the caller's ctx into an abort. It
	// exits once watchDone is closed after both copies have returned.
	watchDone := make(chan struct{})
	watchStopped := make(chan struct{})
	go func() {
		defer close(watchStopped)
		select {
		case <-ctx.Done():
			abortBridge(ctx.Err())
		case <-watchDone:
		}
	}()

	// runCopy starts one copy direction. Cancelling bridgeCtx is left
	// entirely to abortBridge, so the failing direction's error is recorded
	// before the other direction can observe a cancelled context.
	runCopy := func(fn func() error) {
		wg.Add(1)
		go func() {
			defer wg.Done()
			err := fn()
			if err != nil {
				abortBridge(err)
			}
			errCh <- err
		}()
	}
	runCopy(func() error {
		_, err := CopyToStream(bridgeCtx, stream, peer, StreamCopyOptions{
			BufferSize: opt.BufferSize,
			CloseWrite: true,
		})
		return err
	})
	runCopy(func() error {
		_, err := copyFromStreamToBridgePeer(bridgeCtx, peer, stream, opt)
		return err
	})
	wg.Wait()
	close(errCh)
	// Both directions are done. Consuming the Once here means a ctx
	// cancellation arriving from now on can no longer close or reset a
	// stream whose transfer already completed. If the watcher is mid-abort,
	// Do waits for it to finish.
	abortOnce.Do(func() {})
	close(watchDone)
	<-watchStopped

	if primaryErr != nil {
		if err := ctx.Err(); err != nil {
			return err
		}
		return primaryErr
	}
	// Defensive fallback: every non-nil copy error goes through abortBridge
	// and sets primaryErr, so in practice only nil values reach this loop.
	var result error
	for err := range errCh {
		if err == nil {
			continue
		}
		if errors.Is(err, context.Canceled) && ctx.Err() == nil {
			continue
		}
		if result == nil {
			result = err
		}
	}
	return result
}
// streamBridgeCloseWriter matches peers that expose CloseWrite (for example
// *net.TCPConn). copyFromStreamToBridgePeer prefers calling CloseWrite over
// Close to signal end-of-data to such a peer once the stream is drained.
type streamBridgeCloseWriter interface {
CloseWrite() error
}
// copyFromStreamToBridgePeer copies stream into peer on behalf of
// BridgeStream and then signals end-of-data to the peer.
//
// After a clean copy it calls peer.CloseWrite when the peer implements
// streamBridgeCloseWriter; otherwise it calls peer.Close only if
// opt.ClosePeerOnEOF is set. The result of that call is returned. A copy
// error is returned without touching the peer.
func copyFromStreamToBridgePeer(ctx context.Context, peer io.ReadWriteCloser, stream Stream, opt StreamBridgeOptions) (int64, error) {
	n, err := CopyFromStream(ctx, peer, stream, StreamCopyOptions{BufferSize: opt.BufferSize})
	if err != nil {
		return n, err
	}
	hc, canCloseWrite := peer.(streamBridgeCloseWriter)
	switch {
	case canCloseWrite:
		return n, hc.CloseWrite()
	case opt.ClosePeerOnEOF:
		return n, peer.Close()
	default:
		return n, nil
	}
}
// OpenClientStreamFromReader opens a new stream on client c with opt.Open
// and writes everything read from src into it according to opt.Copy.
//
// It returns the stream, the number of bytes written, and any error. A nil
// client yields errStreamClientNil.
func OpenClientStreamFromReader(ctx context.Context, c Client, src io.Reader, opt StreamOpenCopyOptions) (Stream, int64, error) {
	if c != nil {
		return openStreamFromReader(ctx, src, opt, func(openCtx context.Context, openOpt StreamOpenOptions) (Stream, error) {
			return c.OpenStream(openCtx, openOpt)
		})
	}
	return nil, 0, errStreamClientNil
}
// OpenServerLogicalStreamFromReader opens a new stream from server s on the
// given logical connection and writes everything read from src into it
// according to opt.
//
// A nil server yields errStreamServerNil; a nil logical connection yields
// errStreamLogicalConnNil.
func OpenServerLogicalStreamFromReader(ctx context.Context, s Server, logical *LogicalConn, src io.Reader, opt StreamOpenCopyOptions) (Stream, int64, error) {
	switch {
	case s == nil:
		return nil, 0, errStreamServerNil
	case logical == nil:
		return nil, 0, errStreamLogicalConnNil
	}
	openOnLogical := func(openCtx context.Context, openOpt StreamOpenOptions) (Stream, error) {
		return s.OpenStreamLogical(openCtx, logical, openOpt)
	}
	return openStreamFromReader(ctx, src, opt, openOnLogical)
}
// OpenServerTransportStreamFromReader opens a new stream from server s on the
// given transport connection and writes everything read from src into it
// according to opt.
//
// A nil server yields errStreamServerNil; a nil transport yields
// errStreamTransportNil.
func OpenServerTransportStreamFromReader(ctx context.Context, s Server, transport *TransportConn, src io.Reader, opt StreamOpenCopyOptions) (Stream, int64, error) {
	switch {
	case s == nil:
		return nil, 0, errStreamServerNil
	case transport == nil:
		return nil, 0, errStreamTransportNil
	}
	openOnTransport := func(openCtx context.Context, openOpt StreamOpenOptions) (Stream, error) {
		return s.OpenStreamTransport(openCtx, transport, openOpt)
	}
	return openStreamFromReader(ctx, src, opt, openOnTransport)
}
// CopyStreamToWriter drains stream into dst. It is CopyFromStream with the
// stream argument placed before dst; see CopyFromStream for the semantics
// of ctx and opt.
func CopyStreamToWriter(ctx context.Context, stream Stream, dst io.Writer, opt StreamCopyOptions) (int64, error) {
	written, err := CopyFromStream(ctx, dst, stream, opt)
	return written, err
}
// openStreamFromReader opens a stream with openFn and copies src into it.
//
// opt is normalized first. Unless the caller set only Copy.CloseStream,
// stream.CloseWrite is called after a successful copy. If the copy fails,
// the stream is Reset with the copy error and returned together with that
// error and the number of bytes written. A nil src or openFn yields
// io.ErrClosedPipe. The same error is returned when openFn reports success
// but hands back a nil Stream, instead of going on to call Reset on a nil
// interface.
func openStreamFromReader(ctx context.Context, src io.Reader, opt StreamOpenCopyOptions, openFn func(context.Context, StreamOpenOptions) (Stream, error)) (Stream, int64, error) {
	if src == nil || openFn == nil {
		return nil, 0, io.ErrClosedPipe
	}
	opt = normalizeStreamOpenCopyOptions(opt)
	stream, err := openFn(ctx, opt.Open)
	if err != nil {
		return nil, 0, err
	}
	if stream == nil {
		// A misbehaving opener returned (nil, nil). CopyToStream would fail
		// with io.ErrClosedPipe and the Reset below would then panic.
		return nil, 0, io.ErrClosedPipe
	}
	written, err := CopyToStream(ctx, stream, src, opt.Copy)
	if err != nil {
		// Best effort: the copy error is what the caller needs to see.
		_ = stream.Reset(err)
		return stream, written, err
	}
	return stream, written, nil
}
// normalizeStreamOpenCopyOptions returns opt with Copy.CloseWrite switched on
// when the caller selected neither Copy.CloseWrite nor Copy.CloseStream. The
// stream opened from a reader is therefore always told that the data has
// ended. opt is taken by value, so the caller's copy is not modified.
func normalizeStreamOpenCopyOptions(opt StreamOpenCopyOptions) StreamOpenCopyOptions {
	if opt.Copy.CloseWrite || opt.Copy.CloseStream {
		return opt
	}
	opt.Copy.CloseWrite = true
	return opt
}