289 lines
6.4 KiB
Go
289 lines
6.4 KiB
Go
|
|
package notify
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"errors"
|
||
|
|
"io"
|
||
|
|
"sync"
|
||
|
|
)
|
||
|
|
|
||
|
|
// StreamCopyOptions tunes the buffered copies done by CopyToStream and
// CopyFromStream.
type StreamCopyOptions struct {
	// BufferSize is the copy buffer length in bytes; values <= 0 fall back
	// to defaultFileChunkSize.
	BufferSize int

	// Close behaviour after a successful copy:
	//   - CloseWrite:  CopyToStream half-closes the stream via CloseWrite.
	//   - CloseStream: CopyToStream fully closes the stream via Close; it is
	//     ignored when CloseWrite is also set.
	//   - CloseWriter: CopyFromStream closes dst when dst implements io.Closer.
	CloseWrite, CloseStream, CloseWriter bool
}
|
||
|
|
|
||
|
|
// StreamBridgeOptions tunes BridgeStream.
type StreamBridgeOptions struct {
	// BufferSize is the copy buffer length used for both directions; values
	// <= 0 fall back to defaultFileChunkSize.
	BufferSize int

	// ClosePeerOnEOF closes the peer once the stream reaches EOF, but only
	// when the peer cannot be half-closed with CloseWrite.
	// ResetOnCopyError makes an aborted bridge Reset the stream with the
	// failure instead of closing it.
	ClosePeerOnEOF, ResetOnCopyError bool
}
|
||
|
|
|
||
|
|
type StreamOpenCopyOptions struct {
|
||
|
|
Open StreamOpenOptions
|
||
|
|
Copy StreamCopyOptions
|
||
|
|
}
|
||
|
|
|
||
|
|
func CopyToStream(ctx context.Context, stream Stream, src io.Reader, opt StreamCopyOptions) (int64, error) {
|
||
|
|
if stream == nil {
|
||
|
|
return 0, io.ErrClosedPipe
|
||
|
|
}
|
||
|
|
if src == nil {
|
||
|
|
return 0, io.ErrClosedPipe
|
||
|
|
}
|
||
|
|
bufSize := opt.BufferSize
|
||
|
|
if bufSize <= 0 {
|
||
|
|
bufSize = defaultFileChunkSize
|
||
|
|
}
|
||
|
|
buf := make([]byte, bufSize)
|
||
|
|
reader := newContextReader(ctx, src)
|
||
|
|
written, err := io.CopyBuffer(stream, reader, buf)
|
||
|
|
if err == nil || err == io.EOF {
|
||
|
|
if opt.CloseWrite {
|
||
|
|
if closeErr := stream.CloseWrite(); closeErr != nil {
|
||
|
|
return written, closeErr
|
||
|
|
}
|
||
|
|
} else if opt.CloseStream {
|
||
|
|
if closeErr := stream.Close(); closeErr != nil {
|
||
|
|
return written, closeErr
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return written, normalizeStreamCopyError(err)
|
||
|
|
}
|
||
|
|
|
||
|
|
func CopyFromStream(ctx context.Context, dst io.Writer, stream Stream, opt StreamCopyOptions) (int64, error) {
|
||
|
|
if stream == nil {
|
||
|
|
return 0, io.ErrClosedPipe
|
||
|
|
}
|
||
|
|
if dst == nil {
|
||
|
|
return 0, io.ErrClosedPipe
|
||
|
|
}
|
||
|
|
bufSize := opt.BufferSize
|
||
|
|
if bufSize <= 0 {
|
||
|
|
bufSize = defaultFileChunkSize
|
||
|
|
}
|
||
|
|
buf := make([]byte, bufSize)
|
||
|
|
reader := newContextReader(ctx, stream)
|
||
|
|
written, err := io.CopyBuffer(dst, reader, buf)
|
||
|
|
if (err == nil || err == io.EOF) && opt.CloseWriter {
|
||
|
|
if closer, ok := dst.(io.Closer); ok {
|
||
|
|
if closeErr := closer.Close(); closeErr != nil {
|
||
|
|
return written, closeErr
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return written, normalizeStreamCopyError(err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// contextReader gates each Read of src on ctx. Once ctx is done, Read fails
// with ctx.Err() without calling src. The check is made only before each
// Read, so a Read already blocked inside src is not interrupted.
type contextReader struct {
	ctx context.Context
	src io.Reader
}

// newContextReader wraps src so that reads stop once ctx is done. When ctx
// or src is nil, src is returned as-is and no wrapping happens.
func newContextReader(ctx context.Context, src io.Reader) io.Reader {
	if ctx != nil && src != nil {
		return &contextReader{ctx: ctx, src: src}
	}
	return src
}

// Read returns ctx.Err() when ctx is already done and otherwise delegates to
// src. A nil receiver, or one without a source, reads as an immediate io.EOF.
func (r *contextReader) Read(p []byte) (int, error) {
	if r == nil || r.src == nil {
		return 0, io.EOF
	}
	// Per the context contract, Err is non-nil exactly when Done is closed.
	if err := r.ctx.Err(); err != nil {
		return 0, err
	}
	return r.src.Read(p)
}
|
||
|
|
|
||
|
|
// normalizeStreamCopyError maps the bare io.EOF sentinel to nil, because
// reaching the end of input is a successful copy. Every other value,
// including wrapped EOFs and nil, is returned unchanged.
func normalizeStreamCopyError(err error) error {
	switch err {
	case io.EOF:
		return nil
	default:
		return err
	}
}
|
||
|
|
|
||
|
|
func BridgeStream(ctx context.Context, stream Stream, peer io.ReadWriteCloser, opt StreamBridgeOptions) error {
|
||
|
|
if stream == nil || peer == nil {
|
||
|
|
return io.ErrClosedPipe
|
||
|
|
}
|
||
|
|
if ctx == nil {
|
||
|
|
ctx = context.Background()
|
||
|
|
}
|
||
|
|
bridgeCtx, cancel := context.WithCancel(ctx)
|
||
|
|
defer cancel()
|
||
|
|
|
||
|
|
var wg sync.WaitGroup
|
||
|
|
errCh := make(chan error, 2)
|
||
|
|
var abortOnce sync.Once
|
||
|
|
var primaryErr error
|
||
|
|
|
||
|
|
abortBridge := func(err error) {
|
||
|
|
abortOnce.Do(func() {
|
||
|
|
primaryErr = err
|
||
|
|
cancel()
|
||
|
|
_ = peer.Close()
|
||
|
|
if err != nil && opt.ResetOnCopyError {
|
||
|
|
_ = stream.Reset(err)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
_ = stream.Close()
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
watchDone := make(chan struct{})
|
||
|
|
watchStopped := make(chan struct{})
|
||
|
|
go func() {
|
||
|
|
defer close(watchStopped)
|
||
|
|
select {
|
||
|
|
case <-ctx.Done():
|
||
|
|
abortBridge(ctx.Err())
|
||
|
|
case <-watchDone:
|
||
|
|
}
|
||
|
|
}()
|
||
|
|
|
||
|
|
runCopy := func(fn func() error) {
|
||
|
|
wg.Add(1)
|
||
|
|
go func() {
|
||
|
|
defer wg.Done()
|
||
|
|
err := fn()
|
||
|
|
if err != nil {
|
||
|
|
abortBridge(err)
|
||
|
|
}
|
||
|
|
errCh <- err
|
||
|
|
}()
|
||
|
|
}
|
||
|
|
|
||
|
|
runCopy(func() error {
|
||
|
|
_, err := CopyToStream(bridgeCtx, stream, peer, StreamCopyOptions{
|
||
|
|
BufferSize: opt.BufferSize,
|
||
|
|
CloseWrite: true,
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
cancel()
|
||
|
|
}
|
||
|
|
return err
|
||
|
|
})
|
||
|
|
|
||
|
|
runCopy(func() error {
|
||
|
|
_, err := copyFromStreamToBridgePeer(bridgeCtx, peer, stream, opt)
|
||
|
|
if err != nil {
|
||
|
|
cancel()
|
||
|
|
}
|
||
|
|
return err
|
||
|
|
})
|
||
|
|
|
||
|
|
wg.Wait()
|
||
|
|
close(errCh)
|
||
|
|
close(watchDone)
|
||
|
|
<-watchStopped
|
||
|
|
|
||
|
|
if primaryErr != nil {
|
||
|
|
if ctx.Err() != nil {
|
||
|
|
return ctx.Err()
|
||
|
|
}
|
||
|
|
return primaryErr
|
||
|
|
}
|
||
|
|
|
||
|
|
var result error
|
||
|
|
for err := range errCh {
|
||
|
|
if err == nil {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
if errors.Is(err, context.Canceled) && ctx.Err() == nil {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
if result == nil {
|
||
|
|
result = err
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return result
|
||
|
|
}
|
||
|
|
|
||
|
|
// streamBridgeCloseWriter is implemented by peers that support a half-close
// of their write side. When the peer implements it, BridgeStream's
// stream -> peer direction uses CloseWrite on EOF instead of Close.
type streamBridgeCloseWriter interface{ CloseWrite() error }
|
||
|
|
|
||
|
|
func copyFromStreamToBridgePeer(ctx context.Context, peer io.ReadWriteCloser, stream Stream, opt StreamBridgeOptions) (int64, error) {
|
||
|
|
written, err := CopyFromStream(ctx, peer, stream, StreamCopyOptions{
|
||
|
|
BufferSize: opt.BufferSize,
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
return written, err
|
||
|
|
}
|
||
|
|
if closeWriter, ok := peer.(streamBridgeCloseWriter); ok {
|
||
|
|
return written, closeWriter.CloseWrite()
|
||
|
|
}
|
||
|
|
if opt.ClosePeerOnEOF {
|
||
|
|
return written, peer.Close()
|
||
|
|
}
|
||
|
|
return written, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func OpenClientStreamFromReader(ctx context.Context, c Client, src io.Reader, opt StreamOpenCopyOptions) (Stream, int64, error) {
|
||
|
|
if c == nil {
|
||
|
|
return nil, 0, errStreamClientNil
|
||
|
|
}
|
||
|
|
return openStreamFromReader(ctx, src, opt, c.OpenStream)
|
||
|
|
}
|
||
|
|
|
||
|
|
func OpenServerLogicalStreamFromReader(ctx context.Context, s Server, logical *LogicalConn, src io.Reader, opt StreamOpenCopyOptions) (Stream, int64, error) {
|
||
|
|
if s == nil {
|
||
|
|
return nil, 0, errStreamServerNil
|
||
|
|
}
|
||
|
|
if logical == nil {
|
||
|
|
return nil, 0, errStreamLogicalConnNil
|
||
|
|
}
|
||
|
|
return openStreamFromReader(ctx, src, opt, func(ctx context.Context, openOpt StreamOpenOptions) (Stream, error) {
|
||
|
|
return s.OpenStreamLogical(ctx, logical, openOpt)
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
func OpenServerTransportStreamFromReader(ctx context.Context, s Server, transport *TransportConn, src io.Reader, opt StreamOpenCopyOptions) (Stream, int64, error) {
|
||
|
|
if s == nil {
|
||
|
|
return nil, 0, errStreamServerNil
|
||
|
|
}
|
||
|
|
if transport == nil {
|
||
|
|
return nil, 0, errStreamTransportNil
|
||
|
|
}
|
||
|
|
return openStreamFromReader(ctx, src, opt, func(ctx context.Context, openOpt StreamOpenOptions) (Stream, error) {
|
||
|
|
return s.OpenStreamTransport(ctx, transport, openOpt)
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
func CopyStreamToWriter(ctx context.Context, stream Stream, dst io.Writer, opt StreamCopyOptions) (int64, error) {
|
||
|
|
return CopyFromStream(ctx, dst, stream, opt)
|
||
|
|
}
|
||
|
|
|
||
|
|
func openStreamFromReader(ctx context.Context, src io.Reader, opt StreamOpenCopyOptions, openFn func(context.Context, StreamOpenOptions) (Stream, error)) (Stream, int64, error) {
|
||
|
|
if src == nil {
|
||
|
|
return nil, 0, io.ErrClosedPipe
|
||
|
|
}
|
||
|
|
if openFn == nil {
|
||
|
|
return nil, 0, io.ErrClosedPipe
|
||
|
|
}
|
||
|
|
opt = normalizeStreamOpenCopyOptions(opt)
|
||
|
|
stream, err := openFn(ctx, opt.Open)
|
||
|
|
if err != nil {
|
||
|
|
return nil, 0, err
|
||
|
|
}
|
||
|
|
written, err := CopyToStream(ctx, stream, src, opt.Copy)
|
||
|
|
if err != nil {
|
||
|
|
_ = stream.Reset(err)
|
||
|
|
return stream, written, err
|
||
|
|
}
|
||
|
|
return stream, written, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func normalizeStreamOpenCopyOptions(opt StreamOpenCopyOptions) StreamOpenCopyOptions {
|
||
|
|
if !opt.Copy.CloseWrite && !opt.Copy.CloseStream {
|
||
|
|
opt.Copy.CloseWrite = true
|
||
|
|
}
|
||
|
|
return opt
|
||
|
|
}
|