diff --git a/circle.go b/circle.go index 6d96c4f..de36679 100644 --- a/circle.go +++ b/circle.go @@ -10,6 +10,10 @@ var ErrStarBufferInvalidCapacity = errors.New("star buffer capacity must be grea var ErrStarBufferClosed = errors.New("star buffer closed") var ErrStarBufferWriteClosed = errors.New("star buffer write closed") +// StarBuffer is a blocking ring buffer that implements stream-style reads and writes. +// +// Close marks the write side finished after all payload bytes are sent. +// Abort aborts both sides immediately but still allows buffered bytes to be drained. type StarBuffer struct { datas []byte pStart uint64 @@ -36,18 +40,21 @@ func NewStarBuffer(cap uint64) (*StarBuffer, error) { return rtnBuffer, nil } +// Free returns the remaining writable capacity. func (star *StarBuffer) Free() uint64 { star.mu.Lock() defer star.mu.Unlock() return star.cap - star.size } +// Cap returns the fixed buffer capacity. func (star *StarBuffer) Cap() uint64 { star.mu.Lock() defer star.mu.Unlock() return star.cap } +// Len returns the currently buffered byte count. func (star *StarBuffer) Len() uint64 { star.mu.Lock() defer star.mu.Unlock() @@ -89,19 +96,21 @@ func (star *StarBuffer) putByte(data byte) error { return nil } -func (star *StarBuffer) EndWrite() error { +// Close closes only the write side and satisfies the usual io.Closer-style +// "producer finished" semantics. +// +// Buffered bytes remain readable until drained; afterwards reads return io.EOF. +func (star *StarBuffer) Close() error { star.mu.Lock() defer star.mu.Unlock() - if star.isClose { - return ErrStarBufferClosed - } - star.isWriteEnd = true - star.notEmpty.Broadcast() - star.notFull.Broadcast() - return nil + return star.closeWriteLocked() } -func (star *StarBuffer) Close() error { +// Abort aborts the buffer and wakes blocked readers/writers immediately. +// +// Buffered bytes remain readable until drained; subsequent writes fail with +// ErrStarBufferClosed. +func (star *StarBuffer) Abort() error { star.mu.Lock() defer star.mu.Unlock() star.isClose = true @@ -111,10 +120,17 @@ func (star *StarBuffer) Close() error { return nil } -func (star *StarBuffer) Read(buf []byte) (int, error) { - if buf == nil { - return 0, errors.New("buffer is nil") +func (star *StarBuffer) closeWriteLocked() error { + if star.isClose { + return ErrStarBufferClosed } + star.isWriteEnd = true + star.notEmpty.Broadcast() + star.notFull.Broadcast() + return nil +} + +func (star *StarBuffer) Read(buf []byte) (int, error) { if len(buf) == 0 { return 0, nil } @@ -140,9 +156,6 @@ func (star *StarBuffer) Read(buf []byte) (int, error) { } func (star *StarBuffer) Write(bts []byte) (int, error) { - if bts == nil { - return 0, star.EndWrite() - } if len(bts) == 0 { return 0, nil } diff --git a/circle_benchmark_test.go b/circle_benchmark_test.go new file mode 100644 index 0000000..d783701 --- /dev/null +++ b/circle_benchmark_test.go @@ -0,0 +1,44 @@ +package stario + +import ( + "io" + "testing" +) + +func BenchmarkStarBufferByteRoundTrip(b *testing.B) { + buf, err := NewStarBuffer(4096) + if err != nil { + b.Fatal(err) + } + b.ReportAllocs() + b.SetBytes(1) + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := buf.putByte('a'); err != nil { + b.Fatal(err) + } + if _, err := buf.getByte(); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkStarBufferChunkRoundTrip(b *testing.B) { + buf, err := NewStarBuffer(8192) + if err != nil { + b.Fatal(err) + } + payload := []byte("hello world b612 hello world b612 b612 b612 b612 b612 b612") + scratch := make([]byte, len(payload)) + b.ReportAllocs() + b.SetBytes(int64(len(payload))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := buf.Write(payload); err != nil { + b.Fatal(err) + } + if _, err := io.ReadFull(buf, scratch); err != nil { + b.Fatal(err) + } + } +} diff --git a/circle_test.go b/circle_test.go index cfec5bd..7b92419 100644 --- a/circle_test.go +++ b/circle_test.go @@ -2,11 +2,8 @@ package stario import ( "bytes" - "fmt" "io" - "sync/atomic" "testing" - "time" ) func TestNewStarBufferRejectsZeroCapacity(t *testing.T) { @@ -19,7 +16,7 @@ func TestNewStarBufferRejectsZeroCapacity(t *testing.T) { } } -func TestStarBufferEndWriteDrainsThenEOF(t *testing.T) { +func TestStarBufferCloseDrainsThenEOF(t *testing.T) { buf, err := NewStarBuffer(4) if err != nil { t.Fatal(err) @@ -27,7 +24,7 @@ func TestStarBufferEndWriteDrainsThenEOF(t *testing.T) { if _, err := buf.Write([]byte("abcd")); err != nil { t.Fatal(err) } - if err := buf.EndWrite(); err != nil { + if err := buf.Close(); err != nil { t.Fatal(err) } @@ -46,11 +43,11 @@ func TestStarBufferEndWriteDrainsThenEOF(t *testing.T) { } if _, err := buf.Write([]byte("x")); err != ErrStarBufferWriteClosed { - t.Fatalf("unexpected write error after EndWrite: %v", err) + t.Fatalf("unexpected write error after Close: %v", err) } } -func TestStarBufferCloseAllowsDrain(t *testing.T) { +func TestStarBufferAbortAllowsDrain(t *testing.T) { buf, err := NewStarBuffer(4) if err != nil { t.Fatal(err) @@ -58,7 +55,7 @@ func TestStarBufferCloseAllowsDrain(t *testing.T) { if _, err := buf.Write([]byte("ab")); err != nil { t.Fatal(err) } - if err := buf.Close(); err != nil { + if err := buf.Abort(); err != nil { t.Fatal(err) } @@ -68,102 +65,38 @@ func TestStarBufferCloseAllowsDrain(t *testing.T) { t.Fatal(err) } if n != 2 || !bytes.Equal(got[:n], []byte("ab")) { - t.Fatalf("unexpected payload after close: n=%d data=%q", n, got[:n]) + t.Fatalf("unexpected payload after abort: n=%d data=%q", n, got[:n]) } n, err = buf.Read(got) if n != 0 || err != io.EOF { - t.Fatalf("expected EOF after draining closed buffer, got n=%d err=%v", n, err) + t.Fatalf("expected EOF after draining aborted buffer, got n=%d err=%v", n, err) } if _, err := buf.Write([]byte("x")); err != ErrStarBufferClosed { - t.Fatalf("unexpected write error after Close: %v", err) + t.Fatalf("unexpected write error after Abort: %v", err) } } -func Test_Circle(t *testing.T) { - buf, err := NewStarBuffer(2048) +func TestStarBufferNilReadIsNoOp(t *testing.T) { + buf, err := NewStarBuffer(4) if err != nil { t.Fatal(err) } - go func() { - for { - //fmt.Println("write start") - buf.Write([]byte("中华人民共和国\n")) - //fmt.Println("write success") - time.Sleep(time.Millisecond * 50) - } - }() - cpp := "" - go func() { - time.Sleep(time.Second * 3) - for { - cache := make([]byte, 64) - ints, err := buf.Read(cache) - if err != nil { - fmt.Println("read error", err) - return - } - if ints != 0 { - cpp += string(cache[:ints]) - } - } - }() - time.Sleep(time.Second * 13) - fmt.Println(cpp) + if n, err := buf.Read(nil); n != 0 || err != nil { + t.Fatalf("expected nil read to be a no-op, got n=%d err=%v", n, err) + } } -func Test_Circle_Speed(t *testing.T) { - buf, err := NewStarBuffer(1048976) +func TestStarBufferNilWriteIsNoOp(t *testing.T) { + buf, err := NewStarBuffer(4) if err != nil { t.Fatal(err) } - count := uint64(0) - for i := 1; i <= 10; i++ { - go func() { - for { - buf.putByte('a') - } - }() + if n, err := buf.Write(nil); n != 0 || err != nil { + t.Fatalf("expected nil write to be a no-op, got n=%d err=%v", n, err) } - for i := 1; i <= 10; i++ { - go func() { - for { - _, err := buf.getByte() - if err == nil { - atomic.AddUint64(&count, 1) - } - } - }() + if _, err := buf.Write([]byte("ab")); err != nil { + t.Fatalf("nil write must not end the write side, got %v", err) } - time.Sleep(time.Second * 10) - fmt.Println(count) -} - -func Test_Circle_Speed2(t *testing.T) { - buf, err := NewStarBuffer(8192) - if err != nil { - t.Fatal(err) - } - count := uint64(0) - for i := 1; i <= 10; i++ { - go func() { - for { - buf.Write([]byte("hello world b612 hello world b612 b612 b612 b612 b612 b612")) - } - }() - } - for i := 1; i <= 10; i++ { - go func() { - for { - mybuf := make([]byte, 1024) - j, err := buf.Read(mybuf) - if err == nil { - atomic.AddUint64(&count, uint64(j)) - } - } - }() - } - time.Sleep(time.Second * 10) - fmt.Println(float64(count) / 10 / 1024 / 1024) } diff --git a/fn.go b/fn.go index b82370b..b7f70eb 100644 --- a/fn.go +++ b/fn.go @@ -1,55 +1,111 @@ package stario import ( + "context" "errors" "time" ) +// ERR_TIMEOUT is the legacy timeout sentinel used by WaitUntilTimeout*. var ERR_TIMEOUT = errors.New("TIME OUT") -func WaitUntilTimeout(tm time.Duration, fn func(chan struct{}) error) error { - var err error - finished := make(chan struct{}) - imout := make(chan struct{}) +// WaitUntilContext runs fn and returns either its result or the context error, +// whichever happens first. +func WaitUntilContext(ctx context.Context, fn func(context.Context) error) error { + if ctx == nil { + ctx = context.Background() + } + finished := make(chan error, 1) go func() { - err = fn(imout) - finished <- struct{}{} + finished <- fn(ctx) }() select { - case <-finished: + case err := <-finished: return err - case <-time.After(tm): - close(imout) - return ERR_TIMEOUT + case <-ctx.Done(): + return ctx.Err() } } -func WaitUntilFinished(fn func() error) <-chan error { - finished := make(chan error) +// WaitUntilContextFinished is the asynchronous form of WaitUntilContext. +func WaitUntilContextFinished(ctx context.Context, fn func(context.Context) error) <-chan error { + result := make(chan error, 1) go func() { - err := fn() - finished <- err - }() - return finished -} - -func WaitUntilTimeoutFinished(tm time.Duration, fn func(chan struct{}) error) <-chan error { - var err error - finished := make(chan struct{}) - result := make(chan error) - imout := make(chan struct{}) - go func() { - err = fn(imout) - finished <- struct{}{} - }() - go func() { - select { - case <-finished: - result <- err - case <-time.After(tm): - close(imout) - result <- ERR_TIMEOUT - } + result <- WaitUntilContext(ctx, fn) + close(result) }() return result } + +// WaitUntilContextDone adapts a done-channel worker to a context-based wait. +func WaitUntilContextDone(ctx context.Context, fn func(<-chan struct{}) error) error { + if ctx == nil { + ctx = context.Background() + } + return WaitUntilContext(ctx, func(context.Context) error { + return fn(ctx.Done()) + }) +} + +// WaitUntilContextDoneFinished is the asynchronous form of WaitUntilContextDone. +func WaitUntilContextDoneFinished(ctx context.Context, fn func(<-chan struct{}) error) <-chan error { + return WaitUntilContextFinished(ctx, func(ctx context.Context) error { + return fn(ctx.Done()) + }) +} + +// WaitUntilTimeout is a legacy timeout helper kept for compatibility. +// +// The provided stop channel must be treated as receive-only by callers. New +// code should prefer WaitUntilContext or WaitUntilContextDone. +func WaitUntilTimeout(tm time.Duration, fn func(chan struct{}) error) error { + ctx, cancel := context.WithTimeout(context.Background(), tm) + defer cancel() + err := WaitUntilContextDone(ctx, func(done <-chan struct{}) error { + return fn(bridgeDoneChan(done)) + }) + if errors.Is(err, context.DeadlineExceeded) { + return ERR_TIMEOUT + } + return err +} + +// WaitUntilFinished runs fn asynchronously and returns its eventual result. +func WaitUntilFinished(fn func() error) <-chan error { + return WaitUntilContextFinished(context.Background(), func(ctx context.Context) error { + return fn() + }) +} + +// WaitUntilTimeoutFinished is the asynchronous form of WaitUntilTimeout. +// +// The provided stop channel must be treated as receive-only by callers. New +// code should prefer WaitUntilContextFinished or WaitUntilContextDoneFinished. +func WaitUntilTimeoutFinished(tm time.Duration, fn func(chan struct{}) error) <-chan error { + result := make(chan error, 1) + go func() { + ctx, cancel := context.WithTimeout(context.Background(), tm) + defer cancel() + err := WaitUntilContextDone(ctx, func(done <-chan struct{}) error { + return fn(bridgeDoneChan(done)) + }) + if errors.Is(err, context.DeadlineExceeded) { + err = ERR_TIMEOUT + } + result <- err + close(result) + }() + return result +} + +func bridgeDoneChan(done <-chan struct{}) chan struct{} { + stop := make(chan struct{}) + if done == nil { + return stop + } + go func() { + <-done + close(stop) + }() + return stop +} diff --git a/fn_test.go b/fn_test.go new file mode 100644 index 0000000..8ed89d5 --- /dev/null +++ b/fn_test.go @@ -0,0 +1,174 @@ +package stario + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestWaitUntilContextReturnsWorkerError(t *testing.T) { + want := errors.New("worker failed") + err := WaitUntilContext(context.Background(), func(ctx context.Context) error { + return want + }) + if !errors.Is(err, want) { + t.Fatalf("unexpected error: got %v want %v", err, want) + } +} + +func TestWaitUntilContextReturnsDeadlineExceeded(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + workerReturned := make(chan struct{}) + err := WaitUntilContext(ctx, func(ctx context.Context) error { + <-ctx.Done() + close(workerReturned) + return nil + }) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("unexpected context error: %v", err) + } + select { + case <-workerReturned: + case <-time.After(200 * time.Millisecond): + t.Fatal("worker did not return after context deadline") + } +} + +func TestWaitUntilContextFinishedReturnsCanceled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + result := WaitUntilContextFinished(ctx, func(ctx context.Context) error { + <-ctx.Done() + return nil + }) + err, ok := <-result + if !ok { + t.Fatal("result channel closed without value") + } + if !errors.Is(err, context.Canceled) { + t.Fatalf("unexpected context error: %v", err) + } + if _, ok := <-result; ok { + t.Fatal("result channel should be closed after delivering the outcome") + } +} + +func TestWaitUntilContextDoneBridgesDoneChannel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + doneClosed := make(chan struct{}) + go func() { + time.Sleep(20 * time.Millisecond) + cancel() + }() + err := WaitUntilContextDone(ctx, func(done <-chan struct{}) error { + <-done + close(doneClosed) + return nil + }) + if !errors.Is(err, context.Canceled) { + t.Fatalf("unexpected context error: %v", err) + } + select { + case <-doneClosed: + case <-time.After(200 * time.Millisecond): + t.Fatal("done channel was not closed when context canceled") + } +} + +func TestWaitUntilContextDoneFinishedReturnsCanceled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + doneClosed := make(chan struct{}) + go func() { + time.Sleep(20 * time.Millisecond) + cancel() + }() + result := WaitUntilContextDoneFinished(ctx, func(done <-chan struct{}) error { + <-done + close(doneClosed) + return nil + }) + err, ok := <-result + if !ok { + t.Fatal("result channel closed without value") + } + if !errors.Is(err, context.Canceled) { + t.Fatalf("unexpected context error: %v", err) + } + select { + case <-doneClosed: + case <-time.After(200 * time.Millisecond): + t.Fatal("done channel was not closed when context canceled") + } + if _, ok := <-result; ok { + t.Fatal("result channel should be closed after delivering the outcome") + } +} + +func TestWaitUntilTimeoutReturnsWorkerError(t *testing.T) { + want := errors.New("worker failed") + err := WaitUntilTimeout(time.Second, func(stop chan struct{}) error { + return want + }) + if !errors.Is(err, want) { + t.Fatalf("unexpected error: got %v want %v", err, want) + } +} + +func TestWaitUntilTimeoutTimesOutWithoutBlockingWorkerReturn(t *testing.T) { + workerReturned := make(chan struct{}) + err := WaitUntilTimeout(20*time.Millisecond, func(stop chan struct{}) error { + <-stop + close(workerReturned) + return nil + }) + if !errors.Is(err, ERR_TIMEOUT) { + t.Fatalf("unexpected timeout error: %v", err) + } + select { + case <-workerReturned: + case <-time.After(200 * time.Millisecond): + t.Fatal("worker did not return after timeout signal") + } +} + +func TestWaitUntilFinishedReturnsWorkerError(t *testing.T) { + want := errors.New("worker failed") + err := <-WaitUntilFinished(func() error { + return want + }) + if !errors.Is(err, want) { + t.Fatalf("unexpected error: got %v want %v", err, want) + } +} + +func TestWaitUntilTimeoutFinishedReturnsTimeout(t *testing.T) { + result := WaitUntilTimeoutFinished(20*time.Millisecond, func(stop chan struct{}) error { + <-stop + return nil + }) + err, ok := <-result + if !ok { + t.Fatal("result channel closed without value") + } + if !errors.Is(err, ERR_TIMEOUT) { + t.Fatalf("unexpected timeout error: %v", err) + } + if _, ok := <-result; ok { + t.Fatal("result channel should be closed after delivering the outcome") + } +} + +func TestWaitUntilTimeoutFinishedReturnsWorkerError(t *testing.T) { + want := errors.New("worker failed") + err, ok := <-WaitUntilTimeoutFinished(time.Second, func(stop chan struct{}) error { + return want + }) + if !ok { + t.Fatal("result channel closed without value") + } + if !errors.Is(err, want) { + t.Fatalf("unexpected error: got %v want %v", err, want) + } +} diff --git a/frameio.go b/frameio.go new file mode 100644 index 0000000..d1e1ea8 --- /dev/null +++ b/frameio.go @@ -0,0 +1,191 @@ +package stario + +import ( + "errors" + "io" +) + +// DefaultFrameReaderBufferSize is the default transport read chunk size used by +// FrameReader. +const DefaultFrameReaderBufferSize = 32 * 1024 + +type frameReaderConnKey struct{} + +// FrameWriter adapts StarQueue framing helpers to an io.Writer. +type FrameWriter struct { + writer io.Writer + queue *StarQueue +} + +// NewFrameWriter creates a framing writer backed by queue. When queue is nil, a +// default StarQueue is created. +func NewFrameWriter(writer io.Writer, queue *StarQueue) *FrameWriter { + if queue == nil { + queue = NewQueue() + } + return &FrameWriter{ + writer: writer, + queue: queue, + } +} + +// WriteFrame writes one framed payload. +func (writer *FrameWriter) WriteFrame(payload []byte) error { + if writer == nil || writer.writer == nil || writer.queue == nil { + return io.ErrClosedPipe + } + return writer.queue.WriteFrame(writer.writer, payload) +} + +// WriteFrameBuffers writes one framed payload using net.Buffers when possible. +func (writer *FrameWriter) WriteFrameBuffers(payload []byte) error { + if writer == nil || writer.writer == nil || writer.queue == nil { + return io.ErrClosedPipe + } + return writer.queue.WriteFrameBuffers(writer.writer, payload) +} + +// WriteFramesBuffers writes multiple framed payloads in one batch when +// possible. +func (writer *FrameWriter) WriteFramesBuffers(payloads ...[]byte) error { + if writer == nil || writer.writer == nil || writer.queue == nil { + return io.ErrClosedPipe + } + return writer.queue.WriteFramesBuffers(writer.writer, payloads...) +} + +// FrameReader adapts StarQueue parsing helpers to an io.Reader. +type FrameReader struct { + reader io.Reader + queue *StarQueue + connKey interface{} + readSize int + pending [][]byte + pendingErr error +} + +// NewFrameReader creates a framing reader backed by queue. When queue is nil, a +// default StarQueue is created. +func NewFrameReader(reader io.Reader, queue *StarQueue) *FrameReader { + if queue == nil { + queue = NewQueue() + } + return &FrameReader{ + reader: reader, + queue: queue, + connKey: &frameReaderConnKey{}, + readSize: DefaultFrameReaderBufferSize, + } +} + +// SetReadBufferSize updates the underlying transport read chunk size. +func (reader *FrameReader) SetReadBufferSize(size int) { + if reader == nil || size <= 0 { + return + } + reader.readSize = size +} + +// SetConnKey overrides the internal queue connection key. +func (reader *FrameReader) SetConnKey(conn interface{}) error { + if reader == nil { + return io.ErrClosedPipe + } + if conn == nil { + reader.connKey = &frameReaderConnKey{} + return nil + } + if err := validateConnKey(conn); err != nil { + return err + } + reader.connKey = conn + return nil +} + +// Next returns the next framed payload. +func (reader *FrameReader) Next() ([]byte, error) { + if reader == nil || reader.reader == nil || reader.queue == nil { + return nil, io.ErrClosedPipe + } + if len(reader.pending) > 0 { + return reader.popPending(), nil + } + if reader.pendingErr != nil { + err := reader.pendingErr + reader.pendingErr = nil + return nil, err + } + if reader.readSize <= 0 { + reader.readSize = DefaultFrameReaderBufferSize + } + buf := make([]byte, reader.readSize) + for { + n, readErr := reader.reader.Read(buf) + if n > 0 { + parseErr := reader.queue.ParseMessageOwned(buf[:n], reader.connKey, func(msg MsgQueue) error { + reader.pending = append(reader.pending, msg.Msg) + return nil + }) + err := reader.normalizeReadErr(parseErr, readErr) + if len(reader.pending) > 0 { + reader.pendingErr = err + return reader.popPending(), nil + } + if err != nil { + return nil, err + } + } + if readErr != nil { + return nil, reader.normalizeReadErr(nil, readErr) + } + } +} + +func (reader *FrameReader) popPending() []byte { + next := reader.pending[0] + reader.pending = reader.pending[1:] + return next +} + +func joinFrameReaderError(parseErr error, readErr error) error { + switch { + case parseErr == nil: + return readErr + case readErr == nil: + return parseErr + default: + return errors.Join(parseErr, readErr) + } +} + +func (reader *FrameReader) normalizeReadErr(parseErr error, readErr error) error { + if errors.Is(readErr, io.EOF) && reader.hasBufferedState() { + reader.clearBufferedState() + readErr = io.ErrUnexpectedEOF + } + return joinFrameReaderError(parseErr, readErr) +} + +func (reader *FrameReader) hasBufferedState() bool { + if reader == nil || reader.queue == nil { + return false + } + stateAny, ok := reader.queue.states.Load(reader.connKey) + if !ok { + return false + } + state, ok := stateAny.(*queConnState) + if !ok { + return false + } + state.mu.Lock() + defer state.mu.Unlock() + return len(state.buf) > 0 +} + +func (reader *FrameReader) clearBufferedState() { + if reader == nil || reader.queue == nil { + return + } + reader.queue.states.Delete(reader.connKey) +} diff --git a/frameio_test.go b/frameio_test.go new file mode 100644 index 0000000..e08a0f0 --- /dev/null +++ b/frameio_test.go @@ -0,0 +1,105 @@ +package stario + +import ( + "bytes" + "errors" + "io" + "testing" +) + +type chunkedReader struct { + data []byte + max int +} + +func (reader *chunkedReader) Read(p []byte) (int, error) { + if len(reader.data) == 0 { + return 0, io.EOF + } + if reader.max > 0 && len(p) > reader.max { + p = p[:reader.max] + } + n := copy(p, reader.data) + reader.data = reader.data[n:] + return n, nil +} + +func TestFrameWriterReaderRoundTrip(t *testing.T) { + var wire bytes.Buffer + writer := NewFrameWriter(&wire, nil) + if err := writer.WriteFrame([]byte("one")); err != nil { + t.Fatalf("WriteFrame failed: %v", err) + } + if err := writer.WriteFramesBuffers([][]byte{[]byte("two"), []byte("three")}...); err != nil { + t.Fatalf("WriteFramesBuffers failed: %v", err) + } + + reader := NewFrameReader(bytes.NewReader(wire.Bytes()), nil) + cases := []string{"one", "two", "three"} + for _, want := range cases { + got, err := reader.Next() + if err != nil { + t.Fatalf("Next returned error: %v", err) + } + if string(got) != want { + t.Fatalf("unexpected frame: got %q want %q", got, want) + } + } + if _, err := reader.Next(); !errors.Is(err, io.EOF) { + t.Fatalf("expected EOF after draining frames, got %v", err) + } +} + +func TestFrameReaderHandlesPartialTransportReads(t *testing.T) { + var wire bytes.Buffer + writer := NewFrameWriter(&wire, nil) + if err := writer.WriteFrame([]byte("hello")); err != nil { + t.Fatalf("WriteFrame failed: %v", err) + } + + reader := NewFrameReader(&chunkedReader{data: wire.Bytes(), max: 3}, nil) + reader.SetReadBufferSize(3) + got, err := reader.Next() + if err != nil { + t.Fatalf("Next returned error: %v", err) + } + if string(got) != "hello" { + t.Fatalf("unexpected frame: %q", got) + } +} + +func TestFrameWriterNilWriterFails(t *testing.T) { + writer := NewFrameWriter(nil, nil) + if err := writer.WriteFrame([]byte("hello")); !errors.Is(err, io.ErrClosedPipe) { + t.Fatalf("expected io.ErrClosedPipe, got %v", err) + } +} + +func TestFrameReaderTruncatedFrameReturnsUnexpectedEOF(t *testing.T) { + queue := NewQueue() + frame := queue.BuildMessage([]byte("hello")) + reader := NewFrameReader(bytes.NewReader(frame[:len(frame)-1]), nil) + + if _, err := reader.Next(); !errors.Is(err, io.ErrUnexpectedEOF) { + t.Fatalf("expected io.ErrUnexpectedEOF, got %v", err) + } +} + +func TestFrameReaderReturnsUnexpectedEOFAfterPendingFrames(t *testing.T) { + queue := NewQueue() + first := queue.BuildMessage([]byte("one")) + second := queue.BuildMessage([]byte("two")) + wire := append(append([]byte{}, first...), second[:len(second)-1]...) + reader := NewFrameReader(bytes.NewReader(wire), nil) + + got, err := reader.Next() + if err != nil { + t.Fatalf("unexpected error on first frame: %v", err) + } + if string(got) != "one" { + t.Fatalf("unexpected first frame: %q", got) + } + if _, err := reader.Next(); !errors.Is(err, io.ErrUnexpectedEOF) { + t.Fatalf("expected io.ErrUnexpectedEOF after draining pending frame, got %v", err) + } +} diff --git a/go.mod b/go.mod index 0103522..fadce3f 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,7 @@ module b612.me/stario -go 1.16 +go 1.20 -require golang.org/x/crypto v0.26.0 +require golang.org/x/term v0.23.0 + +require golang.org/x/sys v0.23.0 // indirect diff --git a/go.sum b/go.sum index f3a76d6..704bf82 100644 --- a/go.sum +++ b/go.sum @@ -1,67 +1,4 @@ -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= -golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= -golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= -golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= -golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= -golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= -golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= -golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM= golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= -golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= -golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= -golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU= golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/io.go b/io.go index 12ec3f4..5610f23 100644 --- a/io.go +++ b/io.go @@ -3,11 +3,12 @@ package stario import ( "bufio" "fmt" - "golang.org/x/crypto/ssh/terminal" + "golang.org/x/term" + "io" "os" - "runtime" "strconv" "strings" + "sync" ) type InputMsg struct { @@ -16,30 +17,54 @@ type InputMsg struct { skipSliceSigErr bool } +type rawInputSignalMode uint8 + +const ( + rawInputSignalIgnore rawInputSignalMode = iota + rawInputSignalExit + rawInputSignalReturnError +) + type rawTerminalSession struct { fd int - state *terminal.State + state *term.State reader *bufio.Reader + input io.Closer redrawHint string printNewline bool + mu sync.Mutex } +var rawTerminalSessionFactory = newRawTerminalSession +var inputSignalHandler = signal + +// Passwd reads one password-style line in raw mode. +// +// When the user presses an input signal such as Ctrl+C, this compatibility +// entry exits the current flow and returns an empty message with a nil error. func Passwd(hint string, defaultVal string) InputMsg { - return passwd(hint, defaultVal, "", true) + return passwd(hint, defaultVal, "", rawInputSignalExit) } +// PasswdWithMask is like Passwd but echoes the provided mask string. func PasswdWithMask(hint string, defaultVal string, mask string) InputMsg { - return passwd(hint, defaultVal, mask, true) + return passwd(hint, defaultVal, mask, rawInputSignalExit) } +// PasswdResponseSignal is like Passwd but preserves input-signal errors for +// callers that need to distinguish Ctrl+C / Ctrl+Z style exits. func PasswdResponseSignal(hint string, defaultVal string) InputMsg { - return passwd(hint, defaultVal, "", true) + return passwd(hint, defaultVal, "", rawInputSignalReturnError) } +// PasswdResponseSignalWithMask is like PasswdResponseSignal but echoes the +// provided mask string. func PasswdResponseSignalWithMask(hint string, defaultVal string, mask string) InputMsg { - return passwd(hint, defaultVal, mask, true) + return passwd(hint, defaultVal, mask, rawInputSignalReturnError) } +// MessageBoxRaw reads one line in raw mode without treating control keys as +// exit signals. func MessageBoxRaw(hint string, defaultVal string) InputMsg { return messageBox(hint, defaultVal) } @@ -48,35 +73,64 @@ func newRawTerminalSession(hint string, printNewline bool) (*rawTerminalSession, if hint != "" { fmt.Print(hint) } - fd := int(os.Stdin.Fd()) - state, err := terminal.MakeRaw(fd) + input, err := openRawTerminalInput() if err != nil { return nil, err } + fd := int(input.Fd()) + state, err := term.MakeRaw(fd) + if err != nil { + _ = input.Close() + return nil, err + } return &rawTerminalSession{ fd: fd, state: state, - reader: bufio.NewReader(os.Stdin), + reader: bufio.NewReader(input), + input: input, redrawHint: promptRedrawHint(hint), printNewline: printNewline, }, nil } func (session *rawTerminalSession) Close() { - if session == nil || session.state == nil { + if session == nil { return } - _ = terminal.Restore(session.fd, session.state) + session.mu.Lock() + defer session.mu.Unlock() + if session.state != nil { + _ = term.Restore(session.fd, session.state) + session.state = nil + } if session.printNewline { fmt.Println() + session.printNewline = false + } + if session.input != nil { + _ = session.input.Close() + session.input = nil } } func (session *rawTerminalSession) Restore() error { - if session == nil || session.state == nil { + if session == nil { return nil } - return terminal.Restore(session.fd, session.state) + session.mu.Lock() + defer session.mu.Unlock() + if session.state == nil { + return nil + } + if err := term.Restore(session.fd, session.state); err != nil { + return err + } + session.state = nil + return nil +} + +func (session *rawTerminalSession) Abort() { + session.Close() } func promptRedrawHint(hint string) string { @@ -101,6 +155,20 @@ func renderRawEcho(ioBuf []rune, mask string) string { return strings.Repeat(mask, len(ioBuf)) } +func rawEchoRenderUnit(r rune, mask string, maskWidth int) (string, int, bool) { + if mask != "" { + if maskWidth <= 0 { + return "", 0, false + } + return mask, maskWidth, true + } + width := runeDisplayWidth(r) + if width <= 0 { + return "", 0, false + } + return string(r), width, true +} + func redrawPromptLine(hint string, echo string, lastWidth int) int { nowWidth := stringDisplayWidth(hint) + stringDisplayWidth(echo) clearWidth := lastWidth @@ -121,6 +189,22 @@ func redrawPromptLine(hint string, echo string, lastWidth int) int { return nowWidth } +func erasePromptTail(width int) { + if width <= 0 { + return + } + backtrack := strings.Repeat("\b", width) + fmt.Print(backtrack) + fmt.Print(strings.Repeat(" ", width)) + fmt.Print(backtrack) +} + +func redrawPromptEcho(hint string, ioBuf []rune, mask string, lastWidth int) (int, int) { + echo := renderRawEcho(ioBuf, mask) + echoWidth := stringDisplayWidth(echo) + return echoWidth, redrawPromptLine(hint, echo, lastWidth) +} + func stringDisplayWidth(text string) int { width := 0 for _, r := range text { @@ -157,46 +241,79 @@ func isWideRune(r rune) bool { (r >= 0x20000 && r <= 0x3fffd)) } -func rawLineInput(hint string, defaultVal string, mask string, handleSignal bool) InputMsg { - session, err := newRawTerminalSession(hint, true) +func signalInputResult(mode rawInputSignalMode, err error) InputMsg { + switch mode { + case rawInputSignalExit: + return InputMsg{msg: "", err: nil} + case rawInputSignalReturnError: + return InputMsg{msg: "", err: err} + default: + return InputMsg{msg: "", err: nil} + } +} + +func rawLineInput(hint string, defaultVal string, mask string, signalMode rawInputSignalMode) InputMsg { + session, err := rawTerminalSessionFactory(hint, true) if err != nil { return InputMsg{msg: "", err: err} } defer session.Close() + return rawLineInputSession(session, defaultVal, mask, signalMode) +} + +func rawLineInputSession(session *rawTerminalSession, defaultVal string, mask string, signalMode rawInputSignalMode) InputMsg { + if session == nil || session.reader == nil { + return InputMsg{msg: "", err: io.ErrClosedPipe} + } ioBuf := make([]rune, 0, 16) - lastWidth := 0 + promptWidth := stringDisplayWidth(session.redrawHint) + maskWidth := stringDisplayWidth(mask) + echoWidth := 0 + lastWidth := promptWidth for { b, _, err := session.reader.ReadRune() if err != nil { return InputMsg{msg: "", err: err} } - if handleSignal && isSignal(b) { - if runtime.GOOS != "windows" { - if err := session.Restore(); err != nil { - return InputMsg{msg: "", err: err} - } + if signalMode != rawInputSignalIgnore && isSignal(b) { + session.Close() + if signalMode == rawInputSignalExit { + return signalInputResult(signalMode, nil) } - if err := signal(b); err != nil { - return InputMsg{msg: "", err: err} - } - continue + return signalInputResult(signalMode, inputSignalHandler(b)) } switch b { case 0x0d, 0x0a: return InputMsg{msg: finalizeInputValue(string(ioBuf), defaultVal), err: nil} case 0x08, 0x7F: if len(ioBuf) > 0 { + removed := ioBuf[len(ioBuf)-1] ioBuf = ioBuf[:len(ioBuf)-1] + if _, removedWidth, ok := rawEchoRenderUnit(removed, mask, maskWidth); ok { + erasePromptTail(removedWidth) + echoWidth -= removedWidth + if echoWidth < 0 { + echoWidth = 0 + } + lastWidth = promptWidth + echoWidth + continue + } } default: ioBuf = append(ioBuf, b) + if appendText, appendWidth, ok := rawEchoRenderUnit(b, mask, maskWidth); ok { + fmt.Print(appendText) + echoWidth += appendWidth + lastWidth = promptWidth + echoWidth + continue + } } - lastWidth = redrawPromptLine(session.redrawHint, renderRawEcho(ioBuf, mask), lastWidth) + echoWidth, lastWidth = redrawPromptEcho(session.redrawHint, ioBuf, mask, lastWidth) } } func messageBox(hint string, defaultVal string) InputMsg { - return rawLineInput(hint, defaultVal, "", false) + return rawLineInput(hint, defaultVal, "", rawInputSignalIgnore) } func isSignal(s rune) bool { @@ -208,10 +325,12 @@ func isSignal(s rune) bool { } } -func passwd(hint string, defaultVal string, mask string, handleSignal bool) InputMsg { - return rawLineInput(hint, defaultVal, mask, handleSignal) +func passwd(hint string, defaultVal string, mask string, signalMode rawInputSignalMode) InputMsg { + return rawLineInput(hint, defaultVal, mask, signalMode) } +// MessageBox reads one line in cooked mode and falls back to defaultVal when +// the trimmed input is empty. func MessageBox(hint string, defaultVal string) InputMsg { if hint != "" { fmt.Print(hint) @@ -264,7 +383,7 @@ func (im InputMsg) sliceFn(sep string, fn func(string) (interface{}, error)) ([] return res, err } for _, v := range data { - code, err := fn(v) + code, err := fn(strings.TrimSpace(v)) if err != nil && !im.skipSliceSigErr { return nil, err } else if err == nil { @@ -428,7 +547,8 @@ func (im InputMsg) MustFloat32() float32 { func (im InputMsg) SliceFloat32(sep string) ([]float32, error) { data, err := im.sliceFn(sep, func(v string) (interface{}, error) { - return strconv.ParseFloat(v, 32) + f, err := strconv.ParseFloat(v, 32) + return float32(f), err }) var res []float32 for _, v := range data { @@ -477,13 +597,47 @@ func YesNoE(hint string, defaults bool) (bool, error) { } } +func buildTriggerPrefixTable(triggerRunes []rune) []int { + if len(triggerRunes) == 0 { + return nil + } + prefix := make([]int, len(triggerRunes)) + for i := 1; i < len(triggerRunes); i++ { + j := prefix[i-1] + for j > 0 && triggerRunes[i] != triggerRunes[j] { + j = prefix[j-1] + } + if triggerRunes[i] == triggerRunes[j] { + j++ + } + prefix[i] = j + } + return prefix +} + +func advanceTriggerIndex(triggerRunes []rune, prefix []int, current int, input rune) (int, bool) { + if len(triggerRunes) == 0 { + return 0, true + } + for current > 0 && input != triggerRunes[current] { + current = prefix[current-1] + } + if input == triggerRunes[current] { + current++ + if current == len(triggerRunes) { + return current, true + } + } + return current, false +} + +// StopUntil keeps reading raw input until trigger is matched. +// When trigger == "", it returns after the first key press, which is used for +// "press any key to continue" style prompts. func StopUntil(hint string, trigger string, repeat bool) error { triggerRunes := []rune(trigger) - pressLen := len(triggerRunes) - if trigger == "" { - pressLen = 1 - } - session, err := newRawTerminalSession(hint, false) + prefix := buildTriggerPrefixTable(triggerRunes) + session, err := rawTerminalSessionFactory(hint, false) if err != nil { return err } @@ -497,11 +651,12 @@ func StopUntil(hint string, trigger string, repeat bool) error { if trigger == "" { break } - if b == triggerRunes[i] { - i++ - if i == pressLen { - break - } + next, complete := advanceTriggerIndex(triggerRunes, prefix, i, b) + if complete { + break + } + if next > 0 { + i = next continue } i = 0 diff --git a/io_benchmark_test.go b/io_benchmark_test.go new file mode 100644 index 0000000..33f994d --- /dev/null +++ b/io_benchmark_test.go @@ -0,0 +1,81 @@ +package stario + +import ( + "strings" + "testing" +) + +func benchmarkRawRedrawAppend(b *testing.B, current []rune, next rune, mask string) { + b.Run("legacy", func(b *testing.B) { + ioBuf := make([]rune, len(current)+1) + copy(ioBuf, current) + ioBuf[len(current)] = next + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + echo := renderRawEcho(ioBuf, mask) + _ = stringDisplayWidth(echo) + } + }) + + b.Run("fast", func(b *testing.B) { + maskWidth := stringDisplayWidth(mask) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, ok := rawEchoRenderUnit(next, mask, maskWidth) + if !ok { + b.Fatal("fast path rejected append rune") + } + } + }) +} + +func benchmarkRawRedrawBackspace(b *testing.B, current []rune, mask string) { + if len(current) == 0 { + b.Fatal("backspace benchmark requires non-empty input") + } + last := current[len(current)-1] + trimmed := current[:len(current)-1] + + b.Run("legacy", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + echo := renderRawEcho(trimmed, mask) + _ = stringDisplayWidth(echo) + } + }) + + b.Run("fast", func(b *testing.B) { + maskWidth := stringDisplayWidth(mask) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, ok := rawEchoRenderUnit(last, mask, maskWidth) + if !ok { + b.Fatal("fast path rejected backspace rune") + } + } + }) +} + +func BenchmarkRawRedrawAppendPlainLong(b *testing.B) { + current := []rune(strings.Repeat("a", 4096)) + benchmarkRawRedrawAppend(b, current, 'b', "") +} + +func BenchmarkRawRedrawAppendMaskLong(b *testing.B) { + current := []rune(strings.Repeat("a", 4096)) + benchmarkRawRedrawAppend(b, current, 'b', "[]") +} + +func BenchmarkRawRedrawBackspacePlainLong(b *testing.B) { + current := []rune(strings.Repeat("a", 4096)) + benchmarkRawRedrawBackspace(b, current, "") +} + +func BenchmarkRawRedrawBackspaceMaskLong(b *testing.B) { + current := []rune(strings.Repeat("a", 4096)) + benchmarkRawRedrawBackspace(b, current, "[]") +} diff --git a/io_context.go b/io_context.go new file mode 100644 index 0000000..440e5b5 --- /dev/null +++ b/io_context.go @@ -0,0 +1,177 @@ +package stario + +import ( + "context" + "errors" + "io" +) + +const defaultCopyContextBufferSize = 32 * 1024 + +type readContextResult struct { + data []byte + err error +} + +type writeContextResult struct { + n int + err error +} + +// ReadFullContext reads exactly len(buf) bytes unless the context is canceled +// or the underlying reader returns an error. +// +// If ctx is canceled while the underlying Read call is blocked, ReadFullContext +// returns ctx.Err() without waiting for that call to finish. The underlying +// reader may still complete asynchronously afterwards. +func ReadFullContext(ctx context.Context, reader io.Reader, buf []byte) (int, error) { + if reader == nil { + return 0, io.ErrClosedPipe + } + if ctx == nil { + ctx = context.Background() + } + if err := ctx.Err(); err != nil { + return 0, err + } + total := 0 + for total < len(buf) { + chunk, err := readContext(ctx, reader, len(buf)-total) + if len(chunk) > 0 { + total += copy(buf[total:], chunk) + } + if err != nil { + if errors.Is(err, io.EOF) { + if total > 0 { + return total, io.ErrUnexpectedEOF + } + return total, io.EOF + } + return total, err + } + if len(chunk) == 0 { + return total, io.ErrNoProgress + } + } + return total, nil +} + +// WriteFullContext writes the full payload unless the context is canceled or +// the underlying writer returns an error. +// +// If ctx is canceled while the underlying Write call is blocked, +// WriteFullContext returns ctx.Err() without waiting for that call to finish. +// The underlying writer may still complete asynchronously afterwards. +func WriteFullContext(ctx context.Context, writer io.Writer, data []byte) (int, error) { + if writer == nil { + return 0, io.ErrClosedPipe + } + if ctx == nil { + ctx = context.Background() + } + if err := ctx.Err(); err != nil { + return 0, err + } + total := 0 + for total < len(data) { + written, err := writeContext(ctx, writer, data[total:]) + if written > 0 { + total += written + } + if err != nil { + return total, err + } + if written == 0 { + return total, io.ErrNoProgress + } + } + return total, nil +} + +// CopyContext copies from src to dst until EOF, context cancellation, or a +// non-EOF error occurs. +// +// If ctx is canceled while the current read or write is blocked, CopyContext +// returns ctx.Err() without waiting for that operation to finish. +func CopyContext(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) { + if dst == nil || src == nil { + return 0, io.ErrClosedPipe + } + if ctx == nil { + ctx = context.Background() + } + if err := ctx.Err(); err != nil { + return 0, err + } + var copied int64 + for { + chunk, readErr := readContext(ctx, src, defaultCopyContextBufferSize) + if len(chunk) > 0 { + written, writeErr := WriteFullContext(ctx, dst, chunk) + copied += int64(written) + if writeErr != nil { + return copied, writeErr + } + } + if readErr != nil { + if errors.Is(readErr, io.EOF) { + return copied, nil + } + return copied, readErr + } + if len(chunk) == 0 { + return copied, io.ErrNoProgress + } + } +} + +func readContext(ctx context.Context, reader io.Reader, size int) ([]byte, error) { + if size <= 0 { + return nil, nil + } + resultCh := make(chan readContextResult, 1) + go func() { + tmp := make([]byte, size) + n, err := reader.Read(tmp) + if n > 0 { + tmp = tmp[:n] + } else { + tmp = nil + } + resultCh <- readContextResult{data: tmp, err: err} + }() + select { + case result := <-resultCh: + return result.data, result.err + case <-ctx.Done(): + select { + case result := <-resultCh: + return result.data, result.err + default: + return nil, ctx.Err() + } + } +} + +func writeContext(ctx context.Context, writer io.Writer, data []byte) (int, error) { + if len(data) == 0 { + return 0, nil + } + payload := append([]byte(nil), data...) + resultCh := make(chan writeContextResult, 1) + go func() { + n, err := writer.Write(payload) + resultCh <- writeContextResult{n: n, err: err} + }() + select { + case result := <-resultCh: + return result.n, result.err + case <-ctx.Done(): + select { + case result := <-resultCh: + return result.n, result.err + default: + return 0, ctx.Err() + } + } +} diff --git a/io_context_test.go b/io_context_test.go new file mode 100644 index 0000000..a9285db --- /dev/null +++ b/io_context_test.go @@ -0,0 +1,153 @@ +package stario + +import ( + "bytes" + "context" + "errors" + "io" + "sync" + "testing" + "time" +) + +type chunkedWriter struct { + buf bytes.Buffer + max int +} + +func (writer *chunkedWriter) Write(p []byte) (int, error) { + if writer.max <= 0 || len(p) <= writer.max { + return writer.buf.Write(p) + } + return writer.buf.Write(p[:writer.max]) +} + +type blockingReader struct { + started chan struct{} + release chan struct{} + data []byte + once sync.Once +} + +func (reader *blockingReader) Read(p []byte) (int, error) { + reader.once.Do(func() { close(reader.started) }) + <-reader.release + n := copy(p, reader.data) + return n, nil +} + +type blockingWriter struct { + started chan struct{} + release chan struct{} + once sync.Once +} + +func (writer *blockingWriter) Write(p []byte) (int, error) { + writer.once.Do(func() { close(writer.started) }) + <-writer.release + return len(p), nil +} + +func TestReadFullContext(t *testing.T) { + buf := make([]byte, 5) + n, err := ReadFullContext(context.Background(), bytes.NewBufferString("hello"), buf) + if err != nil { + t.Fatalf("ReadFullContext returned error: %v", err) + } + if n != 5 || string(buf) != "hello" { + t.Fatalf("unexpected payload: n=%d data=%q", n, buf) + } +} + +func TestReadFullContextReturnsUnexpectedEOF(t *testing.T) { + buf := make([]byte, 5) + n, err := ReadFullContext(context.Background(), bytes.NewBufferString("hey"), buf) + if !errors.Is(err, io.ErrUnexpectedEOF) { + t.Fatalf("expected unexpected EOF, got %v", err) + } + if n != 3 || string(buf[:n]) != "hey" { + t.Fatalf("unexpected payload: n=%d data=%q", n, buf[:n]) + } +} + +func TestReadFullContextCanceledWhileBlocked(t *testing.T) { + reader := &blockingReader{ + started: make(chan struct{}), + release: make(chan struct{}), + data: []byte("hello"), + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + done := make(chan error, 1) + go func() { + buf := make([]byte, 5) + _, err := ReadFullContext(ctx, reader, buf) + done <- err + }() + + <-reader.started + cancel() + + select { + case err := <-done: + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context canceled, got %v", err) + } + case <-time.After(200 * time.Millisecond): + t.Fatal("ReadFullContext did not return after cancel") + } + + close(reader.release) +} + +func TestWriteFullContext(t *testing.T) { + writer := &chunkedWriter{max: 2} + n, err := WriteFullContext(context.Background(), writer, []byte("hello")) + if err != nil { + t.Fatalf("WriteFullContext returned error: %v", err) + } + if n != 5 || writer.buf.String() != "hello" { + t.Fatalf("unexpected write result: n=%d data=%q", n, writer.buf.String()) + } +} + +func TestWriteFullContextCanceledWhileBlocked(t *testing.T) { + writer := &blockingWriter{ + started: make(chan struct{}), + release: make(chan struct{}), + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + done := make(chan error, 1) + go func() { + _, err := WriteFullContext(ctx, writer, []byte("hello")) + done <- err + }() + + <-writer.started + cancel() + + select { + case err := <-done: + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context canceled, got %v", err) + } + case <-time.After(200 * time.Millisecond): + t.Fatal("WriteFullContext did not return after cancel") + } + + close(writer.release) +} + +func TestCopyContext(t *testing.T) { + var dst bytes.Buffer + written, err := CopyContext(context.Background(), &dst, bytes.NewBufferString("hello world")) + if err != nil { + t.Fatalf("CopyContext returned error: %v", err) + } + if written != int64(len("hello world")) || dst.String() != "hello world" { + t.Fatalf("unexpected copy result: written=%d data=%q", written, dst.String()) + } +} diff --git a/io_test.go b/io_test.go index f42af27..457daf1 100644 --- a/io_test.go +++ b/io_test.go @@ -1,10 +1,32 @@ package stario import ( + "bufio" + "errors" "fmt" + "strings" "testing" ) +func installRawInputStub(t *testing.T, input string, signalErr error) { + t.Helper() + prevFactory := rawTerminalSessionFactory + prevSignalHandler := inputSignalHandler + rawTerminalSessionFactory = func(hint string, printNewline bool) (*rawTerminalSession, error) { + return &rawTerminalSession{ + reader: bufio.NewReader(strings.NewReader(input)), + redrawHint: promptRedrawHint(hint), + }, nil + } + inputSignalHandler = func(sigtype rune) error { + return signalErr + } + t.Cleanup(func() { + rawTerminalSessionFactory = prevFactory + inputSignalHandler = prevSignalHandler + }) +} + func TestPromptRedrawHint(t *testing.T) { got := promptRedrawHint("头部提示\n 中文确认: ") if got != "中文确认:" { @@ -19,6 +41,81 @@ func TestStringDisplayWidth(t *testing.T) { } } +func TestRawEchoRenderUnitPlainRune(t *testing.T) { + text, width, ok := rawEchoRenderUnit('中', "", 0) + if !ok { + t.Fatal("expected plain wide rune to use fast path") + } + if text != "中" || width != 2 { + t.Fatalf("unexpected render unit: got (%q, %d) want (%q, %d)", text, width, "中", 2) + } +} + +func TestRawEchoRenderUnitMaskedRune(t *testing.T) { + text, width, ok := rawEchoRenderUnit('a', "[]", stringDisplayWidth("[]")) + if !ok { + t.Fatal("expected masked rune to use fast path") + } + if text != "[]" || width != 2 { + t.Fatalf("unexpected render unit: got (%q, %d) want (%q, %d)", text, width, "[]", 2) + } +} + +func TestRawEchoRenderUnitFallsBackForControlRune(t *testing.T) { + if _, _, ok := rawEchoRenderUnit('\x00', "", 0); ok { + t.Fatal("expected control rune to fall back to full redraw") + } +} + +func TestSignalInputResultExitSuppressesError(t *testing.T) { + got := signalInputResult(rawInputSignalExit, ErrSignalInterrupt) + if got.err != nil { + t.Fatalf("expected nil error, got %v", got.err) + } + if got.msg != "" { + t.Fatalf("expected empty message, got %q", got.msg) + } +} + +func TestSignalInputResultReturnErrorPreservesSignal(t *testing.T) { + got := signalInputResult(rawInputSignalReturnError, ErrSignalInterrupt) + if !errors.Is(got.err, ErrSignalInterrupt) { + t.Fatalf("expected signal error, got %v", got.err) + } + if got.msg != "" { + t.Fatalf("expected empty message, got %q", got.msg) + } +} + +func TestPasswdSuppressesSignalError(t *testing.T) { + installRawInputStub(t, string([]rune{0x03}), ErrSignalInterrupt) + got := Passwd("", "fallback") + if got.err != nil { + t.Fatalf("expected nil error, got %v", got.err) + } + if got.msg != "" { + t.Fatalf("expected empty message after signal exit, got %q", got.msg) + } +} + +func TestPasswdResponseSignalPreservesSignalError(t *testing.T) { + installRawInputStub(t, string([]rune{0x03}), ErrSignalInterrupt) + got := PasswdResponseSignal("", "fallback") + if !errors.Is(got.err, ErrSignalInterrupt) { + t.Fatalf("expected interrupt error, got %v", got.err) + } + if got.msg != "" { + t.Fatalf("expected empty message after signal exit, got %q", got.msg) + } +} + +func TestStopUntilEmptyTriggerReturnsAfterFirstKey(t *testing.T) { + installRawInputStub(t, "abc", nil) + if err := StopUntil("", "", false); err != nil { + t.Fatalf("StopUntil returned error: %v", err) + } +} + func TestParseYesNoValue(t *testing.T) { cases := []struct { name string @@ -40,6 +137,59 @@ func TestParseYesNoValue(t *testing.T) { } } +func TestSliceFloat32(t *testing.T) { + data := InputMsg{msg: "1.5,2.25", err: nil} + got, err := data.SliceFloat32(",") + if err != nil { + t.Fatalf("SliceFloat32 returned error: %v", err) + } + if len(got) != 2 || got[0] != float32(1.5) || got[1] != float32(2.25) { + t.Fatalf("unexpected float32 slice: %#v", got) + } +} + +func TestTypedSliceParsingTrimsTokenWhitespace(t *testing.T) { + ints, err := (InputMsg{msg: "1, 2, 3"}).SliceInt(",") + if err != nil { + t.Fatalf("SliceInt returned error: %v", err) + } + if len(ints) != 3 || ints[0] != 1 || ints[1] != 2 || ints[2] != 3 { + t.Fatalf("unexpected int slice: %#v", ints) + } + + bools, err := (InputMsg{msg: "true, false, true"}).SliceBool(",") + if err != nil { + t.Fatalf("SliceBool returned error: %v", err) + } + if len(bools) != 3 || !bools[0] || bools[1] || !bools[2] { + t.Fatalf("unexpected bool slice: %#v", bools) + } + + float64s, err := (InputMsg{msg: "1.25, 2.5, 3.75"}).SliceFloat64(",") + if err != nil { + t.Fatalf("SliceFloat64 returned error: %v", err) + } + if len(float64s) != 3 || float64s[0] != 1.25 || float64s[1] != 2.5 || float64s[2] != 3.75 { + t.Fatalf("unexpected float64 slice: %#v", float64s) + } +} + +func TestAdvanceTriggerIndexHandlesOverlap(t *testing.T) { + trigger := []rune("aba") + prefix := buildTriggerPrefixTable(trigger) + index := 0 + complete := false + for _, r := range []rune("aaba") { + index, complete = advanceTriggerIndex(trigger, prefix, index, r) + if complete { + break + } + } + if !complete { + t.Fatal("expected overlapped trigger to complete") + } +} + func Test_Slice(t *testing.T) { var data = InputMsg{ msg: "true,false,true,true,false,0,1,hello", diff --git a/pipe.go b/pipe.go new file mode 100644 index 0000000..4272a7f --- /dev/null +++ b/pipe.go @@ -0,0 +1,71 @@ +package stario + +import "io" + +// StarPipeReader is the read side returned by NewStarPipe. +type StarPipeReader struct { + buf *StarBuffer +} + +// StarPipeWriter is the write side returned by NewStarPipe. +type StarPipeWriter struct { + buf *StarBuffer +} + +// NewStarPipe creates a buffered in-memory pipe backed by StarBuffer. +// +// The writer side uses Close to signal a graceful end-of-stream and Abort to +// fail both sides immediately. +func NewStarPipe(capacity uint64) (*StarPipeReader, *StarPipeWriter, error) { + buf, err := NewStarBuffer(capacity) + if err != nil { + return nil, nil, err + } + return &StarPipeReader{buf: buf}, &StarPipeWriter{buf: buf}, nil +} + +// Read reads buffered bytes from the pipe. +func (reader *StarPipeReader) Read(p []byte) (int, error) { + if reader == nil || reader.buf == nil { + return 0, io.ErrClosedPipe + } + return reader.buf.Read(p) +} + +// Close aborts the pipe from the read side and wakes blocked writers. +func (reader *StarPipeReader) Close() error { + if reader == nil || reader.buf == nil { + return io.ErrClosedPipe + } + return reader.buf.Abort() +} + +// Abort aborts the pipe from the read side and wakes blocked writers. +func (reader *StarPipeReader) Abort() error { + return reader.Close() +} + +// Write writes bytes into the pipe buffer. +func (writer *StarPipeWriter) Write(p []byte) (int, error) { + if writer == nil || writer.buf == nil { + return 0, io.ErrClosedPipe + } + return writer.buf.Write(p) +} + +// Close gracefully closes the write side. Buffered bytes remain readable until +// drained. +func (writer *StarPipeWriter) Close() error { + if writer == nil || writer.buf == nil { + return io.ErrClosedPipe + } + return writer.buf.Close() +} + +// Abort aborts the pipe immediately. +func (writer *StarPipeWriter) Abort() error { + if writer == nil || writer.buf == nil { + return io.ErrClosedPipe + } + return writer.buf.Abort() +} diff --git a/pipe_test.go b/pipe_test.go new file mode 100644 index 0000000..f97545d --- /dev/null +++ b/pipe_test.go @@ -0,0 +1,53 @@ +package stario + +import ( + "bytes" + "errors" + "io" + "testing" +) + +func TestStarPipeRoundTrip(t *testing.T) { + reader, writer, err := NewStarPipe(16) + if err != nil { + t.Fatal(err) + } + want := []byte("hello world") + go func() { + _, _ = writer.Write(want) + _ = writer.Close() + }() + + got, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("ReadAll failed: %v", err) + } + if !bytes.Equal(got, want) { + t.Fatalf("unexpected payload: got %q want %q", got, want) + } +} + +func TestStarPipeReaderCloseAbortsWriter(t *testing.T) { + reader, writer, err := NewStarPipe(4) + if err != nil { + t.Fatal(err) + } + if err := reader.Close(); err != nil { + t.Fatal(err) + } + if _, err := writer.Write([]byte("x")); !errors.Is(err, ErrStarBufferClosed) { + t.Fatalf("unexpected writer error after reader close: %v", err) + } +} + +func TestStarPipeNilEndsReportClosedPipe(t *testing.T) { + var reader *StarPipeReader + var writer *StarPipeWriter + + if _, err := reader.Read(make([]byte, 1)); !errors.Is(err, io.ErrClosedPipe) { + t.Fatalf("unexpected nil reader error: %v", err) + } + if _, err := writer.Write([]byte("x")); !errors.Is(err, io.ErrClosedPipe) { + t.Fatalf("unexpected nil writer error: %v", err) + } +} diff --git a/que.go b/que.go deleted file mode 100644 index 8c02643..0000000 --- a/que.go +++ /dev/null @@ -1,325 +0,0 @@ -package stario - -import ( - "bytes" - "context" - "encoding/binary" - "errors" - "fmt" - "os" - "sync" - "time" -) - -var ErrDeadlineExceeded error = errors.New("deadline exceeded") - -// 识别头 -var header = []byte{11, 27, 19, 96, 12, 25, 02, 20} - -// MsgQueue 为基本的信息单位 -type MsgQueue struct { - ID uint16 - Msg []byte - Conn interface{} -} - -// StarQueue 为流数据中的消息队列分发 -type StarQueue struct { - maxLength uint32 - count int64 - Encode bool - msgID uint16 - msgPool chan MsgQueue - unFinMsg sync.Map - lastID int //= -1 - ctx context.Context - cancel context.CancelFunc - duration time.Duration - EncodeFunc func([]byte) []byte - DecodeFunc func([]byte) []byte - //restoreMu sync.Mutex -} - -func NewQueueCtx(ctx context.Context, count int64, maxMsgLength uint32) *StarQueue { - var q StarQueue - q.Encode = false - q.count = count - q.maxLength = maxMsgLength - q.msgPool = make(chan MsgQueue, count) - if ctx == nil { - q.ctx, q.cancel = context.WithCancel(context.Background()) - } else { - q.ctx, q.cancel = context.WithCancel(ctx) - } - q.duration = 0 - return &q -} -func NewQueueWithCount(count int64) *StarQueue { - return NewQueueCtx(nil, count, 0) -} - -// NewQueue 建立一个新消息队列 -func NewQueue() *StarQueue { - return NewQueueWithCount(32) -} - -// Uint32ToByte 4位uint32转byte -func Uint32ToByte(src uint32) []byte { - res := make([]byte, 4) - res[3] = uint8(src) - res[2] = uint8(src >> 8) - res[1] = uint8(src >> 16) - res[0] = uint8(src >> 24) - return res -} - -// ByteToUint32 byte转4位uint32 -func ByteToUint32(src []byte) uint32 { - var res uint32 - buffer := bytes.NewBuffer(src) - binary.Read(buffer, binary.BigEndian, &res) - return res -} - -// Uint16ToByte 2位uint16转byte -func Uint16ToByte(src uint16) []byte { - res := make([]byte, 2) - res[1] = uint8(src) - res[0] = uint8(src >> 8) - return res -} - -// ByteToUint16 用于byte转uint16 -func ByteToUint16(src []byte) uint16 { - var res uint16 - buffer := bytes.NewBuffer(src) - binary.Read(buffer, binary.BigEndian, &res) - return res -} - -// BuildMessage 生成编码后的信息用于发送 -func (q *StarQueue) BuildMessage(src []byte) []byte { - var buff bytes.Buffer - q.msgID++ - if q.Encode { - src = q.EncodeFunc(src) - } - length := uint32(len(src)) - buff.Write(header) - buff.Write(Uint32ToByte(length)) - buff.Write(Uint16ToByte(q.msgID)) - buff.Write(src) - return buff.Bytes() -} - -// BuildHeader 生成编码后的Header用于发送 -func (q *StarQueue) BuildHeader(length uint32) []byte { - var buff bytes.Buffer - q.msgID++ - buff.Write(header) - buff.Write(Uint32ToByte(length)) - buff.Write(Uint16ToByte(q.msgID)) - return buff.Bytes() -} - -type unFinMsg struct { - ID uint16 - LengthRecv uint32 - // HeaderMsg 信息头,应当为14位:8位识别码+4位长度码+2位id - HeaderMsg []byte - RecvMsg []byte -} - -func (q *StarQueue) push2list(msg MsgQueue) { - q.msgPool <- msg -} - -// ParseMessage 用于解析收到的msg信息 -func (q *StarQueue) ParseMessage(msg []byte, conn interface{}) error { - return q.parseMessage(msg, conn) -} - -// parseMessage 用于解析收到的msg信息 -func (q *StarQueue) parseMessage(msg []byte, conn interface{}) error { - tmp, ok := q.unFinMsg.Load(conn) - if ok { //存在未完成的信息 - lastMsg := tmp.(*unFinMsg) - headerLen := len(lastMsg.HeaderMsg) - if headerLen < 14 { //未完成头标题 - //传输的数据不能填充header头 - if len(msg) < 14-headerLen { - //加入header头并退出 - lastMsg.HeaderMsg = bytesMerge(lastMsg.HeaderMsg, msg) - q.unFinMsg.Store(conn, lastMsg) - return nil - } - //获取14字节完整的header - header := msg[0 : 14-headerLen] - lastMsg.HeaderMsg = bytesMerge(lastMsg.HeaderMsg, header) - //检查收到的header是否为认证header - //若不是,丢弃并重新来过 - if !checkHeader(lastMsg.HeaderMsg[0:8]) { - q.unFinMsg.Delete(conn) - if len(msg) == 0 { - return nil - } - return q.parseMessage(msg, conn) - } - //获得本数据包长度 - lastMsg.LengthRecv = ByteToUint32(lastMsg.HeaderMsg[8:12]) - if q.maxLength != 0 && lastMsg.LengthRecv > q.maxLength { - q.unFinMsg.Delete(conn) - return fmt.Errorf("msg length is %d ,too large than %d", lastMsg.LengthRecv, q.maxLength) - } - //获得本数据包ID - lastMsg.ID = ByteToUint16(lastMsg.HeaderMsg[12:14]) - //存入列表 - q.unFinMsg.Store(conn, lastMsg) - msg = msg[14-headerLen:] - if uint32(len(msg)) < lastMsg.LengthRecv { - lastMsg.RecvMsg = msg - q.unFinMsg.Store(conn, lastMsg) - return nil - } - if uint32(len(msg)) >= lastMsg.LengthRecv { - lastMsg.RecvMsg = msg[0:lastMsg.LengthRecv] - if q.Encode { - lastMsg.RecvMsg = q.DecodeFunc(lastMsg.RecvMsg) - } - msg = msg[lastMsg.LengthRecv:] - storeMsg := MsgQueue{ - ID: lastMsg.ID, - Msg: lastMsg.RecvMsg, - Conn: conn, - } - //q.restoreMu.Lock() - q.push2list(storeMsg) - //q.restoreMu.Unlock() - q.unFinMsg.Delete(conn) - return q.parseMessage(msg, conn) - } - } else { - lastID := int(lastMsg.LengthRecv) - len(lastMsg.RecvMsg) - if lastID < 0 { - q.unFinMsg.Delete(conn) - return q.parseMessage(msg, conn) - } - if len(msg) >= lastID { - lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg[0:lastID]) - if q.Encode { - lastMsg.RecvMsg = q.DecodeFunc(lastMsg.RecvMsg) - } - storeMsg := MsgQueue{ - ID: lastMsg.ID, - Msg: lastMsg.RecvMsg, - Conn: conn, - } - q.push2list(storeMsg) - q.unFinMsg.Delete(conn) - if len(msg) == lastID { - return nil - } - msg = msg[lastID:] - return q.parseMessage(msg, conn) - } - lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg) - q.unFinMsg.Store(conn, lastMsg) - return nil - } - } - if len(msg) == 0 { - return nil - } - var start int - if start = searchHeader(msg); start == -1 { - return errors.New("data format error") - } - msg = msg[start:] - lastMsg := unFinMsg{} - q.unFinMsg.Store(conn, &lastMsg) - return q.parseMessage(msg, conn) -} - -func checkHeader(msg []byte) bool { - if len(msg) != 8 { - return false - } - for k, v := range msg { - if v != header[k] { - return false - } - } - return true -} - -func searchHeader(msg []byte) int { - if len(msg) < 8 { - return 0 - } - for k, v := range msg { - find := 0 - if v == header[0] { - for k2, v2 := range header { - if msg[k+k2] == v2 { - find++ - } else { - break - } - } - if find == 8 { - return k - } - } - } - return -1 -} - -func bytesMerge(src ...[]byte) []byte { - var buff bytes.Buffer - for _, v := range src { - buff.Write(v) - } - return buff.Bytes() -} - -// Restore 获取收到的信息 -func (q *StarQueue) Restore() (MsgQueue, error) { - if q.duration.Seconds() == 0 { - q.duration = 86400 * time.Second - } - for { - select { - case <-q.ctx.Done(): - return MsgQueue{}, errors.New("Stoped By External Function Call") - case <-time.After(q.duration): - if q.duration != 0 { - return MsgQueue{}, ErrDeadlineExceeded - } - case data, ok := <-q.msgPool: - if !ok { - return MsgQueue{}, os.ErrClosed - } - return data, nil - } - } -} - -// RestoreOne 获取收到的一个信息 -// 兼容性修改 -func (q *StarQueue) RestoreOne() (MsgQueue, error) { - return q.Restore() -} - -// Stop 立即停止Restore -func (q *StarQueue) Stop() { - q.cancel() -} - -// RestoreDuration Restore最大超时时间 -func (q *StarQueue) RestoreDuration(tm time.Duration) { - q.duration = tm -} - -func (q *StarQueue) RestoreChan() <-chan MsgQueue { - return q.msgPool -} diff --git a/que_benchmark_test.go b/que_benchmark_test.go new file mode 100644 index 0000000..992430a --- /dev/null +++ b/que_benchmark_test.go @@ -0,0 +1,80 @@ +package stario + +import ( + "io" + "testing" +) + +func BenchmarkQueueBuildMessage64KiB(b *testing.B) { + que := NewQueue() + payload := make([]byte, 64*1024) + b.ReportAllocs() + b.SetBytes(int64(len(payload))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = que.BuildMessage(payload) + } +} + +func BenchmarkQueueWriteFrame64KiB(b *testing.B) { + que := NewQueue() + payload := make([]byte, 64*1024) + b.ReportAllocs() + b.SetBytes(int64(len(payload))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := que.WriteFrame(io.Discard, payload); err != nil { + b.Fatalf("WriteFrame failed: %v", err) + } + } +} + +func BenchmarkQueueWriteFrameBuffers64KiB(b *testing.B) { + que := NewQueue() + payload := make([]byte, 64*1024) + b.ReportAllocs() + b.SetBytes(int64(len(payload))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := que.WriteFrameBuffers(io.Discard, payload); err != nil { + b.Fatalf("WriteFrameBuffers failed: %v", err) + } + } +} + +func BenchmarkQueueWriteFramesBuffers4x64KiB(b *testing.B) { + que := NewQueue() + payloads := [][]byte{ + make([]byte, 64*1024), + make([]byte, 64*1024), + make([]byte, 64*1024), + make([]byte, 64*1024), + } + b.ReportAllocs() + b.SetBytes(int64(len(payloads) * len(payloads[0]))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := que.WriteFramesBuffers(io.Discard, payloads...); err != nil { + b.Fatalf("WriteFramesBuffers failed: %v", err) + } + } +} + +func BenchmarkQueueParseRestoreHello(b *testing.B) { + que := NewQueueWithCount(1) + frame := que.BuildMessage([]byte("hello")) + if len(frame) == 0 { + b.Fatal("BuildMessage returned empty frame") + } + b.ReportAllocs() + b.SetBytes(int64(len(frame))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := que.ParseMessage(frame, "bench"); err != nil { + b.Fatal(err) + } + if _, err := que.Restore(); err != nil { + b.Fatal(err) + } + } +} diff --git a/que_bytes.go b/que_bytes.go new file mode 100644 index 0000000..3779bd5 --- /dev/null +++ b/que_bytes.go @@ -0,0 +1,34 @@ +package stario + +import "encoding/binary" + +// Uint32ToByte 4位uint32转byte。 +func Uint32ToByte(src uint32) []byte { + res := make([]byte, 4) + binary.BigEndian.PutUint32(res, src) + return res +} + +// ByteToUint32 byte转4位uint32。 +func ByteToUint32(src []byte) uint32 { + return binary.BigEndian.Uint32(src) +} + +// Uint16ToByte 2位uint16转byte。 +func Uint16ToByte(src uint16) []byte { + res := make([]byte, 2) + binary.BigEndian.PutUint16(res, src) + return res +} + +// ByteToUint16 用于byte转uint16。 +func ByteToUint16(src []byte) uint16 { + return binary.BigEndian.Uint16(src) +} + +func cloneBytes(src []byte) []byte { + if len(src) == 0 { + return nil + } + return append([]byte(nil), src...) +} diff --git a/que_frame.go b/que_frame.go new file mode 100644 index 0000000..3694c1c --- /dev/null +++ b/que_frame.go @@ -0,0 +1,208 @@ +package stario + +import ( + "encoding/binary" + "fmt" + "io" + "net" +) + +// BuildMessage builds one full frame and panics if the payload is too large to +// fit in the framing format. +// +// New code should prefer BuildMessageErr when it needs an explicit error path. +func (q *StarQueue) BuildMessage(src []byte) []byte { + frame, err := q.BuildMessageErr(src) + if err != nil { + panic(err) + } + return frame +} + +// BuildMessageErr builds one full frame and returns an explicit error when the +// payload exceeds the 32-bit frame length. +func (q *StarQueue) BuildMessageErr(src []byte) ([]byte, error) { + payload := src + if q.Encode && q.EncodeFunc != nil { + payload = q.EncodeFunc(payload) + } + length, err := payloadSizeToUint32(uint64(len(payload))) + if err != nil { + return nil, err + } + header := q.BuildHeader(length) + frame := make([]byte, 0, len(header)+len(payload)) + frame = append(frame, header...) + frame = append(frame, payload...) + return frame, nil +} + +// WriteFrame writes one framed payload directly to w without building an +// intermediate full-frame slice first. +func (q *StarQueue) WriteFrame(w io.Writer, src []byte) error { + if w == nil { + return io.ErrClosedPipe + } + payload := src + if q.Encode && q.EncodeFunc != nil { + payload = q.EncodeFunc(payload) + } + length, err := payloadSizeToUint32(uint64(len(payload))) + if err != nil { + return err + } + var header [queHeaderSize]byte + writeHeaderBytes(header[:], queHeader{ + Length: length, + Version: queVersionV1, + Flags: queSupportedFlags, + }) + if err := writeFull(w, header[:]); err != nil { + return err + } + return writeFull(w, payload) +} + +// WriteFrameBuffers writes one framed payload using net.Buffers so callers can +// opt into gather writes when the underlying writer supports it well. +func (q *StarQueue) WriteFrameBuffers(w io.Writer, src []byte) error { + if w == nil { + return io.ErrClosedPipe + } + payload := src + if q.Encode && q.EncodeFunc != nil { + payload = q.EncodeFunc(payload) + } + length, err := payloadSizeToUint32(uint64(len(payload))) + if err != nil { + return err + } + var header [queHeaderSize]byte + writeHeaderBytes(header[:], queHeader{ + Length: length, + Version: queVersionV1, + Flags: queSupportedFlags, + }) + buffers := net.Buffers{header[:], payload} + _, err = buffers.WriteTo(w) + return err +} + +// WriteFramesBuffers writes multiple framed payloads using one net.Buffers +// batch so callers can reduce write calls on stream transports. +func (q *StarQueue) WriteFramesBuffers(w io.Writer, payloads ...[]byte) error { + if w == nil { + return io.ErrClosedPipe + } + if len(payloads) == 0 { + return nil + } + buffers := make(net.Buffers, 0, len(payloads)*2) + headers := make([][queHeaderSize]byte, len(payloads)) + for i, src := range payloads { + payload := src + if q.Encode && q.EncodeFunc != nil { + payload = q.EncodeFunc(payload) + } + length, err := payloadSizeToUint32(uint64(len(payload))) + if err != nil { + return err + } + writeHeaderBytes(headers[i][:], queHeader{ + Length: length, + Version: queVersionV1, + Flags: queSupportedFlags, + }) + buffers = append(buffers, headers[i][:], payload) + } + _, err := buffers.WriteTo(w) + return err +} + +// BuildHeader 生成编码后的Header用于发送。 +func (q *StarQueue) BuildHeader(length uint32) []byte { + return buildHeaderBytes(queHeader{ + Length: length, + Version: queVersionV1, + Flags: queSupportedFlags, + }) +} + +func buildHeaderBytes(header queHeader) []byte { + buf := make([]byte, queHeaderSize) + writeHeaderBytes(buf, header) + return buf +} + +func writeHeaderBytes(dst []byte, header queHeader) { + if len(dst) < queHeaderSize { + return + } + copy(dst[:queMagicSize], queMagic) + binary.BigEndian.PutUint32(dst[queMagicSize:queMagicSize+4], header.Length) + dst[12] = header.Version + dst[13] = header.Flags +} + +func payloadSizeToUint32(size uint64) (uint32, error) { + const maxFramePayload = ^uint32(0) + if size > uint64(maxFramePayload) { + return 0, fmt.Errorf("%w: %d > %d", ErrQueueMessageTooLarge, size, uint64(maxFramePayload)) + } + return uint32(size), nil +} + +func writeFull(w io.Writer, data []byte) error { + for len(data) > 0 { + n, err := w.Write(data) + if n > 0 { + data = data[n:] + } + if err != nil { + return err + } + if n == 0 { + return io.ErrNoProgress + } + } + return nil +} + +func parseHeaderBytes(src []byte, maxLength uint32) (queHeader, error) { + if len(src) < queHeaderSize { + return queHeader{}, ErrQueueDataFormat + } + if !equalMagic(src[:queMagicSize]) { + return queHeader{}, ErrQueueDataFormat + } + + header := queHeader{ + Length: ByteToUint32(src[queMagicSize : queMagicSize+4]), + Version: src[12], + Flags: src[13], + } + + if header.Version != queVersionV1 { + return queHeader{}, fmt.Errorf("%w: %d", ErrQueueUnsupportedVersion, header.Version) + } + if header.Flags != queSupportedFlags { + return queHeader{}, fmt.Errorf("%w: %d", ErrQueueUnsupportedFlags, header.Flags) + } + if maxLength != 0 && header.Length > maxLength { + return queHeader{}, fmt.Errorf("%w: %d > %d", ErrQueueMessageTooLarge, header.Length, maxLength) + } + + return header, nil +} + +func equalMagic(src []byte) bool { + if len(src) != queMagicSize { + return false + } + for i, b := range src { + if b != queMagic[i] { + return false + } + } + return true +} diff --git a/que_parse.go b/que_parse.go new file mode 100644 index 0000000..d695839 --- /dev/null +++ b/que_parse.go @@ -0,0 +1,237 @@ +package stario + +import ( + "bytes" + "fmt" + "reflect" +) + +// ParseMessage 用于解析收到的msg信息。 +func (q *StarQueue) ParseMessage(msg []byte, conn interface{}) error { + return q.parseMessage(msg, conn, false, func(payload []byte) error { + return q.push2list(MsgQueue{ + Msg: payload, + Conn: conn, + }) + }) +} + +// ParseMessageView parses frames and exposes each payload to fn without +// forcing StarQueue to clone it first. +// +// The provided payload is only valid during the callback. If q uses the legacy +// DecodeFunc path, StarQueue still has to allocate a decoded payload first. +func (q *StarQueue) ParseMessageView(msg []byte, conn interface{}, fn func(FrameView) error) error { + if fn == nil { + return ErrQueueFrameHandlerNil + } + return q.parseMessage(msg, conn, true, func(payload []byte) error { + return fn(FrameView{ + Payload: payload, + Conn: conn, + }) + }) +} + +// ParseMessageOwned parses frames and emits owned payload copies to fn without +// routing them through RestoreChan. +// +// Compared with ParseMessage, this keeps StarQueue state handling the same but +// lets callers decide how to dispatch parsed messages themselves. +func (q *StarQueue) ParseMessageOwned(msg []byte, conn interface{}, fn func(MsgQueue) error) error { + if fn == nil { + return ErrQueueFrameHandlerNil + } + frames := make([]MsgQueue, 0, 1) + parseErr := q.parseMessage(msg, conn, false, func(payload []byte) error { + frames = append(frames, MsgQueue{ + Msg: payload, + Conn: conn, + }) + return nil + }) + for _, frame := range frames { + if err := fn(frame); err != nil { + if parseErr != nil { + return fmt.Errorf("%v: %w", parseErr, err) + } + return err + } + } + return parseErr +} + +func (q *StarQueue) parseMessage(msg []byte, conn interface{}, borrowPayload bool, emit func([]byte) error) error { + state, err := q.connState(conn) + if err != nil { + return err + } + var firstErr error + for { + payload, ok, err := q.nextPayload(state, conn, msg, borrowPayload) + msg = nil + if err != nil && firstErr == nil { + firstErr = err + } + if !ok { + break + } + if err := emit(payload); err != nil { + if firstErr != nil { + return fmt.Errorf("%v: %w", firstErr, err) + } + return err + } + } + return firstErr +} + +func (q *StarQueue) nextPayload(state *queConnState, conn interface{}, msg []byte, borrowPayload bool) ([]byte, bool, error) { + state.mu.Lock() + defer state.mu.Unlock() + + if len(msg) != 0 { + state.buf = append(state.buf, msg...) + } + + var firstErr error + for { + synced, err := syncFrameStart(&state.buf) + if err != nil && firstErr == nil { + firstErr = err + } + if !synced { + if len(state.buf) == 0 { + q.states.Delete(conn) + } + return nil, false, firstErr + } + if len(state.buf) < queHeaderSize { + return nil, false, firstErr + } + + header, err := parseHeaderBytes(state.buf[:queHeaderSize], q.maxLength) + if err != nil { + if firstErr == nil { + firstErr = err + } + state.buf = shrinkBuffer(state.buf[1:]) + continue + } + + frameLen := queHeaderSize + int(header.Length) + if len(state.buf) < frameLen { + return nil, false, firstErr + } + + payload, rest := extractPayload(state.buf, frameLen, borrowPayload && !(q.Encode && q.DecodeFunc != nil)) + state.buf = rest + if q.Encode && q.DecodeFunc != nil { + payload = q.DecodeFunc(payload) + } + if len(state.buf) == 0 { + q.states.Delete(conn) + } + return payload, true, firstErr + } +} + +func extractPayload(buf []byte, frameLen int, borrowPayload bool) ([]byte, []byte) { + payload := buf[queHeaderSize:frameLen] + if !borrowPayload { + return cloneBytes(payload), shrinkBuffer(buf[frameLen:]) + } + if frameLen == len(buf) { + return payload, nil + } + return payload, cloneBytes(buf[frameLen:]) +} + +func (q *StarQueue) push2list(msg MsgQueue) error { + select { + case <-q.ctx.Done(): + return q.ctx.Err() + default: + } + q.sendMu.RLock() + defer q.sendMu.RUnlock() + select { + case <-q.ctx.Done(): + return q.ctx.Err() + case q.msgPool <- msg: + return nil + } +} + +func validateConnKey(conn interface{}) error { + if conn == nil { + return ErrQueueConnKeyNil + } + typ := reflect.TypeOf(conn) + if typ != nil && !typ.Comparable() { + return ErrQueueConnKeyInvalid + } + return nil +} + +func (q *StarQueue) connState(conn interface{}) (*queConnState, error) { + if err := validateConnKey(conn); err != nil { + return nil, err + } + state, _ := q.states.LoadOrStore(conn, &queConnState{}) + return state.(*queConnState), nil +} + +func syncFrameStart(buf *[]byte) (bool, error) { + if len(*buf) == 0 { + return false, nil + } + + if len(*buf) >= queMagicSize && equalMagic((*buf)[:queMagicSize]) { + return true, nil + } + + idx := bytes.Index(*buf, queMagic) + if idx == 0 { + return true, nil + } + if idx > 0 { + *buf = cloneBytes((*buf)[idx:]) + return true, ErrQueueDataFormat + } + + keep := trailingMagicPrefixLen(*buf) + if keep == len(*buf) { + return false, nil + } + if keep > 0 { + *buf = cloneBytes((*buf)[len(*buf)-keep:]) + return false, ErrQueueDataFormat + } + + *buf = (*buf)[:0] + return false, ErrQueueDataFormat +} + +func trailingMagicPrefixLen(buf []byte) int { + max := len(buf) + if max > queMagicSize-1 { + max = queMagicSize - 1 + } + for keep := max; keep > 0; keep-- { + if bytes.Equal(buf[len(buf)-keep:], queMagic[:keep]) { + return keep + } + } + return 0 +} + +func shrinkBuffer(buf []byte) []byte { + if len(buf) == 0 { + return nil + } + if cap(buf) > len(buf)*4 { + return cloneBytes(buf) + } + return buf +} diff --git a/que_runtime.go b/que_runtime.go new file mode 100644 index 0000000..2d1af68 --- /dev/null +++ b/que_runtime.go @@ -0,0 +1,80 @@ +package stario + +import ( + "os" + "time" +) + +func (q *StarQueue) closedRestoreErr() error { + if err := q.ctx.Err(); err != nil { + return err + } + return os.ErrClosed +} + +// Restore blocks until one message is available, the queue is stopped, or the +// configured timeout expires. +func (q *StarQueue) Restore() (MsgQueue, error) { + if q.duration <= 0 { + select { + case <-q.ctx.Done(): + return MsgQueue{}, q.ctx.Err() + case data, ok := <-q.msgPool: + if !ok { + return MsgQueue{}, q.closedRestoreErr() + } + return data, nil + } + } + timer := time.NewTimer(q.duration) + defer timer.Stop() + select { + case <-q.ctx.Done(): + return MsgQueue{}, q.ctx.Err() + case <-timer.C: + return MsgQueue{}, ErrDeadlineExceeded + case data, ok := <-q.msgPool: + if !ok { + return MsgQueue{}, q.closedRestoreErr() + } + return data, nil + } +} + +// RestoreOne 获取收到的一个信息。 +// 兼容性修改。 +func (q *StarQueue) RestoreOne() (MsgQueue, error) { + return q.Restore() +} + +// Stop cancels the queue runtime. +// +// After Stop returns, Restore unblocks with context.Canceled and RestoreChan is +// eventually closed. +func (q *StarQueue) Stop() { + q.cancel() + q.shutdown() +} + +func (q *StarQueue) shutdown() { + q.stopOnce.Do(func() { + q.sendMu.Lock() + close(q.msgPool) + q.sendMu.Unlock() + }) +} + +// RestoreDuration sets the Restore timeout. A non-positive duration means wait +// forever until a message arrives or Stop is called. +func (q *StarQueue) RestoreDuration(tm time.Duration) { + q.duration = tm +} + +// RestoreChan exposes the parsed message stream. +// +// The returned channel is closed after Stop or when the queue context is +// canceled. New code should still prefer Restore when it needs timeout/error +// classification instead of a plain stream. +func (q *StarQueue) RestoreChan() <-chan MsgQueue { + return q.msgPool +} diff --git a/que_test.go b/que_test.go index d4a693a..1fb12a3 100644 --- a/que_test.go +++ b/que_test.go @@ -1,42 +1,520 @@ package stario import ( - "fmt" + "bytes" + "context" + "errors" + "io" "testing" "time" ) -func Test_QueSpeed(t *testing.T) { - que := NewQueueWithCount(0) - stop := make(chan struct{}, 1) - que.RestoreDuration(time.Second * 10) - var count int64 - go func() { - for { - select { - case <-stop: - //fmt.Println(count) - return - default: - } - _, err := que.RestoreOne() - if err == nil { - count++ - } - } - }() - cp := 0 - stoped := time.After(time.Second * 10) - data := que.BuildMessage([]byte("hello")) - for { - select { - case <-stoped: - fmt.Println(count, cp) - stop <- struct{}{} - return - default: - que.ParseMessage(data, "lala") - cp++ - } +func TestQueueBuildMessageUsesVersionedHeader(t *testing.T) { + que := NewQueue() + frame := que.BuildMessage([]byte("hello")) + + if len(frame) != queHeaderSize+5 { + t.Fatalf("unexpected frame length: got %d want %d", len(frame), queHeaderSize+5) + } + if !bytes.Equal(frame[:queMagicSize], queMagic) { + t.Fatalf("unexpected magic: %v", frame[:queMagicSize]) + } + if got := ByteToUint32(frame[queMagicSize : queMagicSize+4]); got != 5 { + t.Fatalf("unexpected payload length: got %d want 5", got) + } + if frame[12] != queVersionV1 { + t.Fatalf("unexpected version: got %d want %d", frame[12], queVersionV1) + } + if frame[13] != queSupportedFlags { + t.Fatalf("unexpected flags: got %d want %d", frame[13], queSupportedFlags) + } + if !bytes.Equal(frame[queHeaderSize:], []byte("hello")) { + t.Fatalf("unexpected payload: %q", frame[queHeaderSize:]) + } +} + +func TestQueueWriteFrameMatchesBuildMessage(t *testing.T) { + que := NewQueue() + want := que.BuildMessage([]byte("hello")) + + var buf bytes.Buffer + if err := que.WriteFrame(&buf, []byte("hello")); err != nil { + t.Fatalf("WriteFrame failed: %v", err) + } + if got := buf.Bytes(); !bytes.Equal(got, want) { + t.Fatalf("WriteFrame mismatch: got %v want %v", got, want) + } +} + +func TestQueueWriteFrameBuffersMatchesBuildMessage(t *testing.T) { + que := NewQueue() + want := que.BuildMessage([]byte("hello")) + + var buf bytes.Buffer + if err := que.WriteFrameBuffers(&buf, []byte("hello")); err != nil { + t.Fatalf("WriteFrameBuffers failed: %v", err) + } + if got := buf.Bytes(); !bytes.Equal(got, want) { + t.Fatalf("WriteFrameBuffers mismatch: got %v want %v", got, want) + } +} + +func TestQueueWriteFramesBuffersMatchesBuildMessage(t *testing.T) { + que := NewQueue() + payloads := [][]byte{ + []byte("hello"), + []byte("world"), + []byte("batch"), + } + + var want []byte + for _, payload := range payloads { + want = append(want, que.BuildMessage(payload)...) + } + + var buf bytes.Buffer + if err := que.WriteFramesBuffers(&buf, payloads...); err != nil { + t.Fatalf("WriteFramesBuffers failed: %v", err) + } + if got := buf.Bytes(); !bytes.Equal(got, want) { + t.Fatalf("WriteFramesBuffers mismatch: got %v want %v", got, want) + } +} + +func TestQueueWriteFrameHonorsEncodeFunc(t *testing.T) { + que := NewQueue() + que.Encode = true + que.EncodeFunc = bytes.ToUpper + want := que.BuildMessage([]byte("hello")) + + var buf bytes.Buffer + if err := que.WriteFrame(&buf, []byte("hello")); err != nil { + t.Fatalf("WriteFrame failed: %v", err) + } + if got := buf.Bytes(); !bytes.Equal(got, want) { + t.Fatalf("WriteFrame mismatch with EncodeFunc: got %v want %v", got, want) + } +} + +func TestQueueWriteFrameBuffersHonorsEncodeFunc(t *testing.T) { + que := NewQueue() + que.Encode = true + que.EncodeFunc = bytes.ToUpper + want := que.BuildMessage([]byte("hello")) + + var buf bytes.Buffer + if err := que.WriteFrameBuffers(&buf, []byte("hello")); err != nil { + t.Fatalf("WriteFrameBuffers failed: %v", err) + } + if got := buf.Bytes(); !bytes.Equal(got, want) { + t.Fatalf("WriteFrameBuffers mismatch with EncodeFunc: got %v want %v", got, want) + } +} + +func TestQueueWriteFramesBuffersHonorsEncodeFunc(t *testing.T) { + que := NewQueue() + que.Encode = true + que.EncodeFunc = bytes.ToUpper + payloads := [][]byte{ + []byte("hello"), + []byte("batch"), + } + + var want []byte + for _, payload := range payloads { + want = append(want, que.BuildMessage(payload)...) + } + + var buf bytes.Buffer + if err := que.WriteFramesBuffers(&buf, payloads...); err != nil { + t.Fatalf("WriteFramesBuffers failed: %v", err) + } + if got := buf.Bytes(); !bytes.Equal(got, want) { + t.Fatalf("WriteFramesBuffers mismatch with EncodeFunc: got %v want %v", got, want) + } +} + +func TestQueueParseMessageSplitAcrossCalls(t *testing.T) { + que := NewQueueWithCount(1) + frame := que.BuildMessage([]byte("hello")) + + if err := que.ParseMessage(frame[:3], "split"); err != nil { + t.Fatalf("unexpected error on partial magic: %v", err) + } + if err := que.ParseMessage(frame[3:11], "split"); err != nil { + t.Fatalf("unexpected error on partial header: %v", err) + } + if err := que.ParseMessage(frame[11:], "split"); err != nil { + t.Fatalf("unexpected error on payload completion: %v", err) + } + + select { + case data := <-que.RestoreChan(): + if data.ID != 0 { + t.Fatalf("expected deprecated frame ID to stay zero, got %d", data.ID) + } + if data.Conn != "split" { + t.Fatalf("unexpected conn: %#v", data.Conn) + } + if !bytes.Equal(data.Msg, []byte("hello")) { + t.Fatalf("unexpected payload: %q", data.Msg) + } + case <-time.After(200 * time.Millisecond): + t.Fatal("did not restore parsed frame") + } +} + +func TestQueueParseMessageViewSplitAcrossCalls(t *testing.T) { + que := NewQueue() + frame := que.BuildMessage([]byte("hello")) + var got [][]byte + + handler := func(view FrameView) error { + got = append(got, cloneBytes(view.Payload)) + if view.Conn != "split-view" { + t.Fatalf("unexpected conn: %#v", view.Conn) + } + return nil + } + + if err := que.ParseMessageView(frame[:3], "split-view", handler); err != nil { + t.Fatalf("unexpected error on partial magic: %v", err) + } + if err := que.ParseMessageView(frame[3:11], "split-view", handler); err != nil { + t.Fatalf("unexpected error on partial header: %v", err) + } + if err := que.ParseMessageView(frame[11:], "split-view", handler); err != nil { + t.Fatalf("unexpected error on payload completion: %v", err) + } + + if len(got) != 1 { + t.Fatalf("unexpected frame count: got %d want 1", len(got)) + } + if !bytes.Equal(got[0], []byte("hello")) { + t.Fatalf("unexpected payload: %q", got[0]) + } +} + +func TestQueueParseMessageOwnedSplitAcrossCalls(t *testing.T) { + que := NewQueue() + frame := que.BuildMessage([]byte("hello")) + var got []MsgQueue + + handler := func(msg MsgQueue) error { + got = append(got, MsgQueue{ + Msg: cloneBytes(msg.Msg), + Conn: msg.Conn, + }) + return nil + } + + if err := que.ParseMessageOwned(frame[:3], "split-owned", handler); err != nil { + t.Fatalf("unexpected error on partial magic: %v", err) + } + if err := que.ParseMessageOwned(frame[3:11], "split-owned", handler); err != nil { + t.Fatalf("unexpected error on partial header: %v", err) + } + if err := que.ParseMessageOwned(frame[11:], "split-owned", handler); err != nil { + t.Fatalf("unexpected error on payload completion: %v", err) + } + + if len(got) != 1 { + t.Fatalf("unexpected frame count: got %d want 1", len(got)) + } + if got[0].Conn != "split-owned" { + t.Fatalf("unexpected conn: %#v", got[0].Conn) + } + if !bytes.Equal(got[0].Msg, []byte("hello")) { + t.Fatalf("unexpected payload: %q", got[0].Msg) + } + select { + case msg := <-que.RestoreChan(): + t.Fatalf("ParseMessageOwned should not use RestoreChan, got %#v", msg) + default: + } +} + +func TestQueueParseMessageSkipsGarbagePrefix(t *testing.T) { + que := NewQueueWithCount(1) + frame := que.BuildMessage([]byte("hello")) + + err := que.ParseMessage(append([]byte("junk"), frame...), "garbage") + if !errors.Is(err, ErrQueueDataFormat) { + t.Fatalf("expected data format error, got %v", err) + } + + select { + case data := <-que.RestoreChan(): + if !bytes.Equal(data.Msg, []byte("hello")) { + t.Fatalf("unexpected payload after resync: %q", data.Msg) + } + case <-time.After(200 * time.Millisecond): + t.Fatal("did not restore frame after skipping garbage") + } +} + +func TestQueueParseMessageViewSkipsGarbagePrefix(t *testing.T) { + que := NewQueue() + frame := que.BuildMessage([]byte("hello")) + var got []byte + + err := que.ParseMessageView(append([]byte("junk"), frame...), "garbage-view", func(view FrameView) error { + got = cloneBytes(view.Payload) + return nil + }) + if !errors.Is(err, ErrQueueDataFormat) { + t.Fatalf("expected data format error, got %v", err) + } + if !bytes.Equal(got, []byte("hello")) { + t.Fatalf("unexpected payload after resync: %q", got) + } +} + +func TestQueueParseMessageViewNilHandler(t *testing.T) { + que := NewQueue() + err := que.ParseMessageView(que.BuildMessage([]byte("hello")), "nil-handler", nil) + if !errors.Is(err, ErrQueueFrameHandlerNil) { + t.Fatalf("ParseMessageView error = %v, want %v", err, ErrQueueFrameHandlerNil) + } +} + +func TestQueueParseMessageOwnedNilHandler(t *testing.T) { + que := NewQueue() + err := que.ParseMessageOwned(que.BuildMessage([]byte("hello")), "nil-handler-owned", nil) + if !errors.Is(err, ErrQueueFrameHandlerNil) { + t.Fatalf("ParseMessageOwned error = %v, want %v", err, ErrQueueFrameHandlerNil) + } +} + +func TestQueueWriteFrameNilWriter(t *testing.T) { + que := NewQueue() + err := que.WriteFrame(nil, []byte("hello")) + if !errors.Is(err, io.ErrClosedPipe) { + t.Fatalf("WriteFrame error = %v, want %v", err, io.ErrClosedPipe) + } +} + +func TestQueueWriteFrameBuffersNilWriter(t *testing.T) { + que := NewQueue() + err := que.WriteFrameBuffers(nil, []byte("hello")) + if !errors.Is(err, io.ErrClosedPipe) { + t.Fatalf("WriteFrameBuffers error = %v, want %v", err, io.ErrClosedPipe) + } +} + +func TestQueueWriteFramesBuffersNilWriter(t *testing.T) { + que := NewQueue() + err := que.WriteFramesBuffers(nil, []byte("hello")) + if !errors.Is(err, io.ErrClosedPipe) { + t.Fatalf("WriteFramesBuffers error = %v, want %v", err, io.ErrClosedPipe) + } +} + +func TestQueueParseMessageRejectsUnsupportedVersion(t *testing.T) { + que := NewQueueWithCount(1) + frame := que.BuildMessage([]byte("hello")) + frame[12] = 2 + + err := que.ParseMessage(frame, "version") + if !errors.Is(err, ErrQueueUnsupportedVersion) { + t.Fatalf("expected unsupported version error, got %v", err) + } + + select { + case data := <-que.RestoreChan(): + t.Fatalf("unexpected restored frame: %#v", data) + default: + } +} + +func TestQueueParseMessageRejectsMessageTooLarge(t *testing.T) { + que := NewQueueCtx(nil, 1, 4) + frame := que.BuildMessage([]byte("hello")) + + err := que.ParseMessage(frame, "large") + if !errors.Is(err, ErrQueueMessageTooLarge) { + t.Fatalf("expected message too large error, got %v", err) + } + + select { + case data := <-que.RestoreChan(): + t.Fatalf("unexpected restored frame: %#v", data) + default: + } +} + +func TestQueueParseMessageRejectsInvalidConnKey(t *testing.T) { + que := NewQueue() + frame := que.BuildMessage([]byte("hello")) + err := que.ParseMessage(frame, []byte("not-comparable")) + if !errors.Is(err, ErrQueueConnKeyInvalid) { + t.Fatalf("expected invalid conn key error, got %v", err) + } +} + +func TestQueueParseMessageRejectsNilConnKey(t *testing.T) { + que := NewQueue() + frame := que.BuildMessage([]byte("hello")) + err := que.ParseMessage(frame, nil) + if !errors.Is(err, ErrQueueConnKeyNil) { + t.Fatalf("expected nil conn key error, got %v", err) + } +} + +func TestQueuePayloadSizeToUint32RejectsOverflow(t *testing.T) { + _, err := payloadSizeToUint32(uint64(^uint32(0)) + 1) + if !errors.Is(err, ErrQueueMessageTooLarge) { + t.Fatalf("expected message too large error, got %v", err) + } +} + +func TestQueueRestoreDurationZeroWaitsUntilMessage(t *testing.T) { + que := NewQueueWithCount(1) + que.RestoreDuration(0) + + type restoreResult struct { + msg MsgQueue + err error + } + resultCh := make(chan restoreResult, 1) + go func() { + msg, err := que.Restore() + resultCh <- restoreResult{msg: msg, err: err} + }() + + select { + case result := <-resultCh: + t.Fatalf("Restore returned too early: %#v", result) + case <-time.After(50 * time.Millisecond): + } + + if err := que.ParseMessage(que.BuildMessage([]byte("hello")), "forever"); err != nil { + t.Fatalf("ParseMessage failed: %v", err) + } + + select { + case result := <-resultCh: + if result.err != nil { + t.Fatalf("Restore returned error: %v", result.err) + } + if result.msg.Conn != "forever" || !bytes.Equal(result.msg.Msg, []byte("hello")) { + t.Fatalf("unexpected restore result: %#v", result.msg) + } + case <-time.After(200 * time.Millisecond): + t.Fatal("Restore did not return after message arrival") + } +} + +func TestQueueRestoreReturnsContextErrorOnStop(t *testing.T) { + que := NewQueue() + resultCh := make(chan error, 1) + go func() { + _, err := que.Restore() + resultCh <- err + }() + + que.Stop() + + select { + case err := <-resultCh: + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context.Canceled, got %v", err) + } + case <-time.After(200 * time.Millisecond): + t.Fatal("Restore did not return after Stop") + } +} + +func TestQueueRestoreChanClosesOnStop(t *testing.T) { + que := NewQueue() + resultCh := make(chan bool, 1) + go func() { + _, ok := <-que.RestoreChan() + resultCh <- ok + }() + + que.Stop() + + select { + case ok := <-resultCh: + if ok { + t.Fatal("expected RestoreChan to close after Stop") + } + case <-time.After(200 * time.Millisecond): + t.Fatal("RestoreChan did not close after Stop") + } +} + +func TestQueueRestoreChanClosesOnContextCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + que := NewQueueCtx(ctx, 1, 0) + resultCh := make(chan bool, 1) + go func() { + _, ok := <-que.RestoreChan() + resultCh <- ok + }() + + cancel() + + select { + case ok := <-resultCh: + if ok { + t.Fatal("expected RestoreChan to close after context cancel") + } + case <-time.After(200 * time.Millisecond): + t.Fatal("RestoreChan did not close after context cancel") + } +} + +func TestQueueParseMessageReturnsContextErrorWhenStoppedWhilePoolIsFull(t *testing.T) { + que := NewQueueWithCount(1) + if err := que.ParseMessage(que.BuildMessage([]byte("first")), "full"); err != nil { + t.Fatalf("ParseMessage first failed: %v", err) + } + + errCh := make(chan error, 1) + go func() { + errCh <- que.ParseMessage(que.BuildMessage([]byte("second")), "full") + }() + + select { + case err := <-errCh: + t.Fatalf("ParseMessage returned before Stop: %v", err) + case <-time.After(50 * time.Millisecond): + } + + que.Stop() + + select { + case err := <-errCh: + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context.Canceled, got %v", err) + } + case <-time.After(200 * time.Millisecond): + t.Fatal("ParseMessage did not return after Stop") + } +} + +func TestQueueParseMessageViewAllowsReentrantParseOnSameConn(t *testing.T) { + que := NewQueue() + reentered := false + + err := que.ParseMessageView(que.BuildMessage([]byte("outer")), "reentrant", func(view FrameView) error { + if !bytes.Equal(view.Payload, []byte("outer")) { + t.Fatalf("unexpected outer payload: %q", view.Payload) + } + return que.ParseMessageView(que.BuildMessage([]byte("inner")), "reentrant", func(inner FrameView) error { + reentered = true + if !bytes.Equal(inner.Payload, []byte("inner")) { + t.Fatalf("unexpected inner payload: %q", inner.Payload) + } + return nil + }) + }) + if err != nil { + t.Fatalf("ParseMessageView failed: %v", err) + } + if !reentered { + t.Fatal("expected reentrant ParseMessageView to run") } } diff --git a/que_types.go b/que_types.go new file mode 100644 index 0000000..48e6233 --- /dev/null +++ b/que_types.go @@ -0,0 +1,101 @@ +package stario + +import ( + "context" + "errors" + "sync" + "time" +) + +var ErrDeadlineExceeded = errors.New("deadline exceeded") +var ErrQueueDataFormat = errors.New("data format error") +var ErrQueueUnsupportedVersion = errors.New("unsupported frame version") +var ErrQueueUnsupportedFlags = errors.New("unsupported frame flags") +var ErrQueueMessageTooLarge = errors.New("message too large") +var ErrQueueFrameHandlerNil = errors.New("frame handler is nil") +var ErrQueueConnKeyInvalid = errors.New("queue conn key must be comparable") +var ErrQueueConnKeyNil = errors.New("queue conn key must not be nil") + +const ( + queMagicSize = 8 + queHeaderSize = 14 + queVersionV1 = 1 + queSupportedFlags = 0 +) + +// 识别头 +var queMagic = []byte{11, 27, 19, 96, 12, 25, 02, 20} + +// MsgQueue 为基本的信息单位。 +type MsgQueue struct { + // Deprecated: frame-level IDs are no longer emitted by StarQueue v2 framing. + ID uint16 + + Msg []byte + Conn interface{} +} + +// FrameView exposes a parsed payload without forcing StarQueue to clone it. +// +// The payload is only valid during the ParseMessageView callback. Callers must +// copy it if they need to keep it after the callback returns. +type FrameView struct { + Payload []byte + Conn interface{} +} + +// StarQueue 为流数据中的消息队列分发。 +type StarQueue struct { + maxLength uint32 + msgPool chan MsgQueue + states sync.Map + ctx context.Context + cancel context.CancelFunc + duration time.Duration + sendMu sync.RWMutex + stopOnce sync.Once + + // Deprecated: new code should keep StarQueue focused on framing only. + Encode bool + // Deprecated: new code should keep StarQueue focused on framing only. + EncodeFunc func([]byte) []byte + // Deprecated: new code should keep StarQueue focused on framing only. + DecodeFunc func([]byte) []byte +} + +type queHeader struct { + Length uint32 + Version uint8 + Flags uint8 +} + +type queConnState struct { + mu sync.Mutex + buf []byte +} + +func NewQueueCtx(ctx context.Context, count int64, maxMsgLength uint32) *StarQueue { + if count < 0 { + panic("stario: negative queue count") + } + q := &StarQueue{ + maxLength: maxMsgLength, + msgPool: make(chan MsgQueue, count), + } + if ctx == nil { + q.ctx, q.cancel = context.WithCancel(context.Background()) + } else { + q.ctx, q.cancel = context.WithCancel(ctx) + } + context.AfterFunc(q.ctx, q.shutdown) + return q +} + +func NewQueueWithCount(count int64) *StarQueue { + return NewQueueCtx(nil, count, 0) +} + +// NewQueue 建立一个新消息队列。 +func NewQueue() *StarQueue { + return NewQueueWithCount(32) +} diff --git a/signal_error.go b/signal_error.go new file mode 100644 index 0000000..687cbb0 --- /dev/null +++ b/signal_error.go @@ -0,0 +1,44 @@ +package stario + +import "errors" + +var ( + ErrSignalInterrupt = errors.New("interrupt") + ErrSignalStop = errors.New("stop") + ErrSignalQuit = errors.New("quit") +) + +type inputSignalError struct { + msg string + cause error +} + +func (e *inputSignalError) Error() string { + return e.msg +} + +func (e *inputSignalError) Unwrap() error { + return e.cause +} + +func signalErrorForType(sigtype rune) error { + switch sigtype { + case 0x03: + return &inputSignalError{ + msg: "SIGNAL SIGINT RECIVED", + cause: ErrSignalInterrupt, + } + case 0x1a: + return &inputSignalError{ + msg: "SIGNAL SIGSTOP RECIVED", + cause: ErrSignalStop, + } + case 0x1c: + return &inputSignalError{ + msg: "SIGNAL SIGQUIT RECIVED", + cause: ErrSignalQuit, + } + default: + return nil + } +} diff --git a/signal_error_test.go b/signal_error_test.go new file mode 100644 index 0000000..718cdc7 --- /dev/null +++ b/signal_error_test.go @@ -0,0 +1,50 @@ +package stario + +import ( + "errors" + "testing" +) + +func TestSignalErrorForType(t *testing.T) { + cases := []struct { + name string + sig rune + msg string + want error + }{ + {name: "interrupt", sig: 0x03, msg: "SIGNAL SIGINT RECIVED", want: ErrSignalInterrupt}, + {name: "stop", sig: 0x1a, msg: "SIGNAL SIGSTOP RECIVED", want: ErrSignalStop}, + {name: "quit", sig: 0x1c, msg: "SIGNAL SIGQUIT RECIVED", want: ErrSignalQuit}, + } + for _, tc := range cases { + err := signalErrorForType(tc.sig) + if err == nil { + t.Fatalf("%s: expected non-nil error", tc.name) + } + if err.Error() != tc.msg { + t.Fatalf("%s: unexpected error text: got %q want %q", tc.name, err.Error(), tc.msg) + } + if !errors.Is(err, tc.want) { + t.Fatalf("%s: errors.Is mismatch: got %v want %v", tc.name, err, tc.want) + } + } +} + +func TestSignalErrorForUnknownType(t *testing.T) { + if err := signalErrorForType('x'); err != nil { + t.Fatalf("expected nil error for unknown signal, got %v", err) + } +} + +func TestSignalReturnsTypedError(t *testing.T) { + err := signal(0x03) + if !errors.Is(err, ErrSignalInterrupt) { + t.Fatalf("expected interrupt error, got %v", err) + } +} + +func TestSignalReturnsNilForUnknownType(t *testing.T) { + if err := signal('x'); err != nil { + t.Fatalf("expected nil error for unknown signal, got %v", err) + } +} diff --git a/signal_windows.go b/signal_windows.go index 9f955f4..5ce21dd 100644 --- a/signal_windows.go +++ b/signal_windows.go @@ -3,20 +3,6 @@ package stario -import ( - "errors" -) - func signal(sigtype rune) error { - //todo: use win32api call signal - switch sigtype { - case 0x03: - return errors.New("SIGNAL SIGINT RECIVED") - case 0x1a: - return errors.New("SIGNAL SIGSTOP RECIVED") - case 0x1c: - return errors.New("SIGNAL SIGQUIT RECIVED") - default: - return nil - } + return signalErrorForType(sigtype) } diff --git a/singal_other.go b/singal_other.go index d300831..253f810 100644 --- a/singal_other.go +++ b/singal_other.go @@ -3,24 +3,6 @@ package stario -import ( - "errors" - "os" - "syscall" -) - func signal(sigtype rune) error { - switch sigtype { - case 0x03: - syscall.Kill(os.Getpid(), syscall.SIGINT) - return errors.New("SIGNAL SIGINT RECIVED") - case 0x1a: - syscall.Kill(os.Getpid(), syscall.SIGSTOP) - return errors.New("SIGNAL SIGSTOP RECIVED") - case 0x1c: - syscall.Kill(os.Getpid(), syscall.SIGQUIT) - return errors.New("SIGNAL SIGQUIT RECIVED") - default: - return nil - } + return signalErrorForType(sigtype) } diff --git a/starring.go b/starring.go new file mode 100644 index 0000000..a3ef1e0 --- /dev/null +++ b/starring.go @@ -0,0 +1,123 @@ +package stario + +import ( + "container/list" + "errors" + "sync" +) + +var ErrStarRingInvalidCapacity = errors.New("star ring capacity must be greater than zero") + +// StarRing keeps the newest bytes in memory with fixed capacity. +// It never blocks writers: when capacity is exceeded, oldest bytes are dropped. +type StarRing struct { + mu sync.RWMutex + cap int + + size int + chunks list.List // each element: []byte +} + +func NewStarRing(capacity int) (*StarRing, error) { + if capacity <= 0 { + return nil, ErrStarRingInvalidCapacity + } + return &StarRing{cap: capacity}, nil +} + +func (s *StarRing) Capacity() int { + s.mu.RLock() + defer s.mu.RUnlock() + return s.cap +} + +func (s *StarRing) Len() int { + s.mu.RLock() + defer s.mu.RUnlock() + return s.size +} + +func (s *StarRing) Reset() { + s.mu.Lock() + defer s.mu.Unlock() + s.size = 0 + s.chunks.Init() +} + +func (s *StarRing) Write(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + payload := append([]byte(nil), p...) + + s.mu.Lock() + defer s.mu.Unlock() + s.writeLocked(payload) + return len(p), nil +} + +func (s *StarRing) WriteString(text string) (int, error) { + if len(text) == 0 { + return 0, nil + } + return s.Write([]byte(text)) +} + +func (s *StarRing) Snapshot() []byte { + s.mu.RLock() + defer s.mu.RUnlock() + + out := make([]byte, s.size) + pos := 0 + for node := s.chunks.Front(); node != nil; node = node.Next() { + chunk, ok := node.Value.([]byte) + if !ok || len(chunk) == 0 { + continue + } + pos += copy(out[pos:], chunk) + } + return out +} + +func (s *StarRing) writeLocked(payload []byte) { + if len(payload) >= s.cap { + payload = payload[len(payload)-s.cap:] + s.chunks.Init() + s.chunks.PushBack(payload) + s.size = len(payload) + return + } + + s.chunks.PushBack(payload) + s.size += len(payload) + s.trimLocked() +} + +func (s *StarRing) trimLocked() { + if s.size <= s.cap { + return + } + overflow := s.size - s.cap + for overflow > 0 { + front := s.chunks.Front() + if front == nil { + s.size = 0 + return + } + head, ok := front.Value.([]byte) + if !ok || len(head) == 0 { + s.chunks.Remove(front) + continue + } + if len(head) <= overflow { + s.chunks.Remove(front) + s.size -= len(head) + overflow -= len(head) + continue + } + trimmed := append([]byte(nil), head[overflow:]...) + front.Value = trimmed + s.size -= overflow + overflow = 0 + } +} diff --git a/starring_test.go b/starring_test.go new file mode 100644 index 0000000..394e131 --- /dev/null +++ b/starring_test.go @@ -0,0 +1,104 @@ +package stario + +import ( + "bytes" + "testing" +) + +func TestNewStarRingRejectsInvalidCapacity(t *testing.T) { + ring, err := NewStarRing(0) + if err != ErrStarRingInvalidCapacity { + t.Fatalf("unexpected error: %v", err) + } + if ring != nil { + t.Fatal("expected nil ring when capacity is invalid") + } +} + +func TestStarRingWriteAndSnapshot(t *testing.T) { + ring, err := NewStarRing(16) + if err != nil { + t.Fatal(err) + } + if n, err := ring.Write([]byte("abc")); err != nil || n != 3 { + t.Fatalf("write failed: n=%d err=%v", n, err) + } + if n, err := ring.WriteString("def"); err != nil || n != 3 { + t.Fatalf("write string failed: n=%d err=%v", n, err) + } + + got := ring.Snapshot() + if !bytes.Equal(got, []byte("abcdef")) { + t.Fatalf("unexpected snapshot: %q", got) + } + if ring.Len() != 6 { + t.Fatalf("unexpected ring len: %d", ring.Len()) + } + if ring.Capacity() != 16 { + t.Fatalf("unexpected ring capacity: %d", ring.Capacity()) + } +} + +func TestStarRingOverflowDropsOldestBytesPrecisely(t *testing.T) { + ring, err := NewStarRing(10) + if err != nil { + t.Fatal(err) + } + _, _ = ring.WriteString("abcde") + _, _ = ring.WriteString("fghij") + _, _ = ring.WriteString("kl") + + got := ring.Snapshot() + want := []byte("cdefghijkl") + if !bytes.Equal(got, want) { + t.Fatalf("unexpected snapshot after overflow: got %q want %q", got, want) + } + if ring.Len() != len(want) { + t.Fatalf("unexpected ring len after overflow: %d", ring.Len()) + } +} + +func TestStarRingLargeWriteKeepsTail(t *testing.T) { + ring, err := NewStarRing(5) + if err != nil { + t.Fatal(err) + } + _, _ = ring.WriteString("123") + _, _ = ring.WriteString("456789") + + got := ring.Snapshot() + want := []byte("56789") + if !bytes.Equal(got, want) { + t.Fatalf("unexpected snapshot after large write: got %q want %q", got, want) + } +} + +func TestStarRingSnapshotIsCopy(t *testing.T) { + ring, err := NewStarRing(8) + if err != nil { + t.Fatal(err) + } + _, _ = ring.WriteString("hello") + + first := ring.Snapshot() + first[0] = 'H' + second := ring.Snapshot() + if !bytes.Equal(second, []byte("hello")) { + t.Fatalf("snapshot should not expose internal storage: %q", second) + } +} + +func TestStarRingReset(t *testing.T) { + ring, err := NewStarRing(8) + if err != nil { + t.Fatal(err) + } + _, _ = ring.WriteString("hello") + ring.Reset() + if ring.Len() != 0 { + t.Fatalf("expected len=0 after reset, got %d", ring.Len()) + } + if got := ring.Snapshot(); len(got) != 0 { + t.Fatalf("expected empty snapshot after reset, got %q", got) + } +} diff --git a/sync.go b/sync.go index 4ca92a0..2be6073 100644 --- a/sync.go +++ b/sync.go @@ -12,6 +12,10 @@ const ( waitGroupAddModeLoose ) +// WaitGroup is a concurrency-limited sync.WaitGroup variant. +// +// A zero or negative limit means unlimited concurrency. WaitGroup must not be +// copied after first use. type WaitGroup struct { wg sync.WaitGroup mu sync.Mutex @@ -22,6 +26,7 @@ type WaitGroup struct { addMode waitGroupAddMode } +// NewWaitGroup creates a WaitGroup with the provided concurrency limit. func NewWaitGroup(maxCount int) WaitGroup { if maxCount < 0 { panic("stario: negative max wait count") @@ -38,6 +43,10 @@ func (w *WaitGroup) init() { }) } +// Add adjusts the running task count. +// +// Positive deltas may block when the concurrency limit is reached. Negative +// deltas release running slots. func (w *WaitGroup) Add(delta int) { w.init() if delta == 0 { @@ -87,10 +96,12 @@ func (w *WaitGroup) release(delta int) { w.cond.Broadcast() } +// Done releases one running task slot. func (w *WaitGroup) Done() { w.Add(-1) } +// Go runs fn in a goroutine while accounting for the concurrency limit. func (w *WaitGroup) Go(fn func()) { w.Add(1) go func() { @@ -99,11 +110,13 @@ func (w *WaitGroup) Go(fn func()) { }() } +// Wait blocks until all added work has completed. func (w *WaitGroup) Wait() { w.init() w.wg.Wait() } +// GetMaxWaitNum returns the current concurrency limit. func (w *WaitGroup) GetMaxWaitNum() int { w.init() w.mu.Lock() @@ -111,6 +124,7 @@ func (w *WaitGroup) GetMaxWaitNum() int { return w.maxCount } +// SetMaxWaitNum updates the concurrency limit. func (w *WaitGroup) SetMaxWaitNum(num int) { if num < 0 { panic("stario: negative max wait count") @@ -122,6 +136,8 @@ func (w *WaitGroup) SetMaxWaitNum(num int) { w.cond.Broadcast() } +// SetStrictAddMode controls whether Add(n>1) panics or auto-expands the limit +// when the requested batch exceeds the current capacity. func (w *WaitGroup) SetStrictAddMode(strict bool) { w.init() w.mu.Lock() @@ -134,6 +150,7 @@ func (w *WaitGroup) SetStrictAddMode(strict bool) { w.cond.Broadcast() } +// StrictAddMode reports whether strict batch-add behavior is enabled. func (w *WaitGroup) StrictAddMode() bool { w.init() w.mu.Lock() diff --git a/terminal.go b/terminal.go new file mode 100644 index 0000000..0374a96 --- /dev/null +++ b/terminal.go @@ -0,0 +1,73 @@ +package stario + +import ( + "context" + "errors" + "os" + + "golang.org/x/term" +) + +// ErrTerminalNotTTY reports that terminal-only input was requested from a +// non-terminal stdin. +var ErrTerminalNotTTY = errors.New("terminal input requires a tty") + +var terminalCheckFunc = term.IsTerminal + +// IsTerminal reports whether os.Stdin is attached to a terminal. +func IsTerminal() bool { + return IsTerminalFile(os.Stdin) +} + +// IsTerminalFD reports whether fd is attached to a terminal. +func IsTerminalFD(fd int) bool { + return terminalCheckFunc(fd) +} + +// IsTerminalFile reports whether file is attached to a terminal. +func IsTerminalFile(file *os.File) bool { + if file == nil { + return false + } + return IsTerminalFD(int(file.Fd())) +} + +// ReadPasswordContext reads one password-style line from the terminal. +// +// It returns ErrTerminalNotTTY when stdin is not a terminal. If ctx is canceled +// while waiting for input, this function returns ctx.Err() without waiting for +// the underlying terminal read to finish. +func ReadPasswordContext(ctx context.Context, hint string, defaultVal string) (string, error) { + return ReadPasswordContextWithMask(ctx, hint, defaultVal, "") +} + +// ReadPasswordContextWithMask is like ReadPasswordContext but echoes the given +// mask string while the user types. +func ReadPasswordContextWithMask(ctx context.Context, hint string, defaultVal string, mask string) (string, error) { + if ctx == nil { + ctx = context.Background() + } + if err := ctx.Err(); err != nil { + return "", err + } + if !IsTerminal() { + return "", ErrTerminalNotTTY + } + session, err := rawTerminalSessionFactory(hint, true) + if err != nil { + return "", err + } + resultCh := make(chan InputMsg, 1) + go func() { + defer session.Close() + resultCh <- rawLineInputSession(session, defaultVal, mask, rawInputSignalReturnError) + }() + select { + case result := <-resultCh: + return result.String() + case <-ctx.Done(): + session.Abort() + <-resultCh + return "", ctx.Err() + } +} diff --git a/terminal_test.go b/terminal_test.go new file mode 100644 index 0000000..9497144 --- /dev/null +++ b/terminal_test.go @@ -0,0 +1,77 @@ +package stario + +import ( + "bufio" + "context" + "errors" + "io" + "strings" + "testing" + "time" +) + +func installTerminalCheckStub(t *testing.T, ok bool) { + t.Helper() + prev := terminalCheckFunc + terminalCheckFunc = func(fd int) bool { + return ok + } + t.Cleanup(func() { + terminalCheckFunc = prev + }) +} + +func TestIsTerminalFDUsesStub(t *testing.T) { + installTerminalCheckStub(t, true) + if !IsTerminalFD(1) { + t.Fatal("expected terminal check to return true") + } +} + +func TestReadPasswordContextRejectsNonTTY(t *testing.T) { + installTerminalCheckStub(t, false) + _, err := ReadPasswordContext(context.Background(), "pwd: ", "") + if !errors.Is(err, ErrTerminalNotTTY) { + t.Fatalf("expected ErrTerminalNotTTY, got %v", err) + } +} + +func TestReadPasswordContextReadsValue(t *testing.T) { + installTerminalCheckStub(t, true) + installRawInputStub(t, "secret\n", nil) + + got, err := ReadPasswordContext(context.Background(), "pwd: ", "") + if err != nil { + t.Fatalf("ReadPasswordContext returned error: %v", err) + } + if got != "secret" { + t.Fatalf("unexpected password: %q", got) + } +} + +func TestReadPasswordContextCanceledWhileWaiting(t *testing.T) { + installTerminalCheckStub(t, true) + + prevFactory := rawTerminalSessionFactory + pipeReader, pipeWriter := io.Pipe() + rawTerminalSessionFactory = func(hint string, printNewline bool) (*rawTerminalSession, error) { + return &rawTerminalSession{ + input: pipeReader, + reader: bufio.NewReader(pipeReader), + redrawHint: strings.TrimSpace(hint), + }, nil + } + t.Cleanup(func() { + rawTerminalSessionFactory = prevFactory + _ = pipeWriter.Close() + _ = pipeReader.Close() + }) + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + + _, err := ReadPasswordContext(ctx, "pwd: ", "") + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected deadline exceeded, got %v", err) + } +} diff --git a/tty_other.go b/tty_other.go new file mode 100644 index 0000000..7958e8b --- /dev/null +++ b/tty_other.go @@ -0,0 +1,10 @@ +//go:build !windows +// +build !windows + +package stario + +import "os" + +func openRawTerminalInput() (*os.File, error) { + return os.OpenFile("/dev/tty", os.O_RDWR, 0) +} diff --git a/tty_windows.go b/tty_windows.go new file mode 100644 index 0000000..ccba7ca --- /dev/null +++ b/tty_windows.go @@ -0,0 +1,10 @@ +//go:build windows +// +build windows + +package stario + +import "os" + +func openRawTerminalInput() (*os.File, error) { + return os.OpenFile("CONIN$", os.O_RDWR, 0) +}