stario: 提升 Go 1.20 基线与交互/队列稳定性

- 提升 go.mod 基线到 Go 1.20,并补齐对应测试
  - 修正 Passwd / PasswdResponseSignal 语义,Ctrl+C 默认退出当前流程
  - 优化 raw terminal redraw、Restore 与 StopUntil 的边界行为
  - 新增 StarPipe、FrameReader/FrameWriter、ReadFullContext/WriteFullContext/CopyContext、IsTerminal/ReadPasswordContext
  - 收口 StarQueue / StarBuffer 语义,删除 EndWrite,统一 Close / Abort 行为
  - 补齐 signal、timeout、queue、terminal、pipe、buffer 的回归测试与 race 覆盖
This commit is contained in:
兔子 2026-04-15 14:35:19 +08:00
parent 3add9183b3
commit c8facb5a03
Signed by: b612
GPG Key ID: 99DD2222B612B612
35 changed files with 3299 additions and 635 deletions

View File

@ -10,6 +10,10 @@ var ErrStarBufferInvalidCapacity = errors.New("star buffer capacity must be grea
var ErrStarBufferClosed = errors.New("star buffer closed") var ErrStarBufferClosed = errors.New("star buffer closed")
var ErrStarBufferWriteClosed = errors.New("star buffer write closed") var ErrStarBufferWriteClosed = errors.New("star buffer write closed")
// StarBuffer is a blocking ring buffer that implements stream-style reads and writes.
//
// Close marks the write side finished after all payload bytes are sent.
// Abort aborts both sides immediately but still allows buffered bytes to be drained.
type StarBuffer struct { type StarBuffer struct {
datas []byte datas []byte
pStart uint64 pStart uint64
@ -36,18 +40,21 @@ func NewStarBuffer(cap uint64) (*StarBuffer, error) {
return rtnBuffer, nil return rtnBuffer, nil
} }
// Free returns the remaining writable capacity.
func (star *StarBuffer) Free() uint64 { func (star *StarBuffer) Free() uint64 {
star.mu.Lock() star.mu.Lock()
defer star.mu.Unlock() defer star.mu.Unlock()
return star.cap - star.size return star.cap - star.size
} }
// Cap returns the fixed buffer capacity.
func (star *StarBuffer) Cap() uint64 { func (star *StarBuffer) Cap() uint64 {
star.mu.Lock() star.mu.Lock()
defer star.mu.Unlock() defer star.mu.Unlock()
return star.cap return star.cap
} }
// Len returns the currently buffered byte count.
func (star *StarBuffer) Len() uint64 { func (star *StarBuffer) Len() uint64 {
star.mu.Lock() star.mu.Lock()
defer star.mu.Unlock() defer star.mu.Unlock()
@ -89,19 +96,21 @@ func (star *StarBuffer) putByte(data byte) error {
return nil return nil
} }
func (star *StarBuffer) EndWrite() error { // Close closes only the write side and satisfies the usual io.Closer-style
// "producer finished" semantics.
//
// Buffered bytes remain readable until drained; afterwards reads return io.EOF.
func (star *StarBuffer) Close() error {
star.mu.Lock() star.mu.Lock()
defer star.mu.Unlock() defer star.mu.Unlock()
if star.isClose { return star.closeWriteLocked()
return ErrStarBufferClosed
}
star.isWriteEnd = true
star.notEmpty.Broadcast()
star.notFull.Broadcast()
return nil
} }
func (star *StarBuffer) Close() error { // Abort aborts the buffer and wakes blocked readers/writers immediately.
//
// Buffered bytes remain readable until drained; subsequent writes fail with
// ErrStarBufferClosed.
func (star *StarBuffer) Abort() error {
star.mu.Lock() star.mu.Lock()
defer star.mu.Unlock() defer star.mu.Unlock()
star.isClose = true star.isClose = true
@ -111,10 +120,17 @@ func (star *StarBuffer) Close() error {
return nil return nil
} }
func (star *StarBuffer) Read(buf []byte) (int, error) { func (star *StarBuffer) closeWriteLocked() error {
if buf == nil { if star.isClose {
return 0, errors.New("buffer is nil") return ErrStarBufferClosed
} }
star.isWriteEnd = true
star.notEmpty.Broadcast()
star.notFull.Broadcast()
return nil
}
func (star *StarBuffer) Read(buf []byte) (int, error) {
if len(buf) == 0 { if len(buf) == 0 {
return 0, nil return 0, nil
} }
@ -140,9 +156,6 @@ func (star *StarBuffer) Read(buf []byte) (int, error) {
} }
func (star *StarBuffer) Write(bts []byte) (int, error) { func (star *StarBuffer) Write(bts []byte) (int, error) {
if bts == nil {
return 0, star.EndWrite()
}
if len(bts) == 0 { if len(bts) == 0 {
return 0, nil return 0, nil
} }

44
circle_benchmark_test.go Normal file
View File

@ -0,0 +1,44 @@
package stario
import (
"io"
"testing"
)
func BenchmarkStarBufferByteRoundTrip(b *testing.B) {
buf, err := NewStarBuffer(4096)
if err != nil {
b.Fatal(err)
}
b.ReportAllocs()
b.SetBytes(1)
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := buf.putByte('a'); err != nil {
b.Fatal(err)
}
if _, err := buf.getByte(); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkStarBufferChunkRoundTrip(b *testing.B) {
buf, err := NewStarBuffer(8192)
if err != nil {
b.Fatal(err)
}
payload := []byte("hello world b612 hello world b612 b612 b612 b612 b612 b612")
scratch := make([]byte, len(payload))
b.ReportAllocs()
b.SetBytes(int64(len(payload)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, err := buf.Write(payload); err != nil {
b.Fatal(err)
}
if _, err := io.ReadFull(buf, scratch); err != nil {
b.Fatal(err)
}
}
}

View File

@ -2,11 +2,8 @@ package stario
import ( import (
"bytes" "bytes"
"fmt"
"io" "io"
"sync/atomic"
"testing" "testing"
"time"
) )
func TestNewStarBufferRejectsZeroCapacity(t *testing.T) { func TestNewStarBufferRejectsZeroCapacity(t *testing.T) {
@ -19,7 +16,7 @@ func TestNewStarBufferRejectsZeroCapacity(t *testing.T) {
} }
} }
func TestStarBufferEndWriteDrainsThenEOF(t *testing.T) { func TestStarBufferCloseDrainsThenEOF(t *testing.T) {
buf, err := NewStarBuffer(4) buf, err := NewStarBuffer(4)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -27,7 +24,7 @@ func TestStarBufferEndWriteDrainsThenEOF(t *testing.T) {
if _, err := buf.Write([]byte("abcd")); err != nil { if _, err := buf.Write([]byte("abcd")); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := buf.EndWrite(); err != nil { if err := buf.Close(); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -46,11 +43,11 @@ func TestStarBufferEndWriteDrainsThenEOF(t *testing.T) {
} }
if _, err := buf.Write([]byte("x")); err != ErrStarBufferWriteClosed { if _, err := buf.Write([]byte("x")); err != ErrStarBufferWriteClosed {
t.Fatalf("unexpected write error after EndWrite: %v", err) t.Fatalf("unexpected write error after Close: %v", err)
} }
} }
func TestStarBufferCloseAllowsDrain(t *testing.T) { func TestStarBufferAbortAllowsDrain(t *testing.T) {
buf, err := NewStarBuffer(4) buf, err := NewStarBuffer(4)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -58,7 +55,7 @@ func TestStarBufferCloseAllowsDrain(t *testing.T) {
if _, err := buf.Write([]byte("ab")); err != nil { if _, err := buf.Write([]byte("ab")); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := buf.Close(); err != nil { if err := buf.Abort(); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -68,102 +65,38 @@ func TestStarBufferCloseAllowsDrain(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if n != 2 || !bytes.Equal(got[:n], []byte("ab")) { if n != 2 || !bytes.Equal(got[:n], []byte("ab")) {
t.Fatalf("unexpected payload after close: n=%d data=%q", n, got[:n]) t.Fatalf("unexpected payload after abort: n=%d data=%q", n, got[:n])
} }
n, err = buf.Read(got) n, err = buf.Read(got)
if n != 0 || err != io.EOF { if n != 0 || err != io.EOF {
t.Fatalf("expected EOF after draining closed buffer, got n=%d err=%v", n, err) t.Fatalf("expected EOF after draining aborted buffer, got n=%d err=%v", n, err)
} }
if _, err := buf.Write([]byte("x")); err != ErrStarBufferClosed { if _, err := buf.Write([]byte("x")); err != ErrStarBufferClosed {
t.Fatalf("unexpected write error after Close: %v", err) t.Fatalf("unexpected write error after Abort: %v", err)
} }
} }
func Test_Circle(t *testing.T) { func TestStarBufferNilReadIsNoOp(t *testing.T) {
buf, err := NewStarBuffer(2048) buf, err := NewStarBuffer(4)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
go func() { if n, err := buf.Read(nil); n != 0 || err != nil {
for { t.Fatalf("expected nil read to be a no-op, got n=%d err=%v", n, err)
//fmt.Println("write start")
buf.Write([]byte("中华人民共和国\n"))
//fmt.Println("write success")
time.Sleep(time.Millisecond * 50)
} }
}()
cpp := ""
go func() {
time.Sleep(time.Second * 3)
for {
cache := make([]byte, 64)
ints, err := buf.Read(cache)
if err != nil {
fmt.Println("read error", err)
return
}
if ints != 0 {
cpp += string(cache[:ints])
}
}
}()
time.Sleep(time.Second * 13)
fmt.Println(cpp)
} }
func Test_Circle_Speed(t *testing.T) { func TestStarBufferNilWriteIsNoOp(t *testing.T) {
buf, err := NewStarBuffer(1048976) buf, err := NewStarBuffer(4)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
count := uint64(0) if n, err := buf.Write(nil); n != 0 || err != nil {
for i := 1; i <= 10; i++ { t.Fatalf("expected nil write to be a no-op, got n=%d err=%v", n, err)
go func() {
for {
buf.putByte('a')
} }
}() if _, err := buf.Write([]byte("ab")); err != nil {
} t.Fatalf("nil write must not end the write side, got %v", err)
for i := 1; i <= 10; i++ {
go func() {
for {
_, err := buf.getByte()
if err == nil {
atomic.AddUint64(&count, 1)
} }
} }
}()
}
time.Sleep(time.Second * 10)
fmt.Println(count)
}
func Test_Circle_Speed2(t *testing.T) {
buf, err := NewStarBuffer(8192)
if err != nil {
t.Fatal(err)
}
count := uint64(0)
for i := 1; i <= 10; i++ {
go func() {
for {
buf.Write([]byte("hello world b612 hello world b612 b612 b612 b612 b612 b612"))
}
}()
}
for i := 1; i <= 10; i++ {
go func() {
for {
mybuf := make([]byte, 1024)
j, err := buf.Read(mybuf)
if err == nil {
atomic.AddUint64(&count, uint64(j))
}
}
}()
}
time.Sleep(time.Second * 10)
fmt.Println(float64(count) / 10 / 1024 / 1024)
}

126
fn.go
View File

@ -1,55 +1,111 @@
package stario package stario
import ( import (
"context"
"errors" "errors"
"time" "time"
) )
// ERR_TIMEOUT is the legacy timeout sentinel used by WaitUntilTimeout*.
var ERR_TIMEOUT = errors.New("TIME OUT") var ERR_TIMEOUT = errors.New("TIME OUT")
func WaitUntilTimeout(tm time.Duration, fn func(chan struct{}) error) error { // WaitUntilContext runs fn and returns either its result or the context error,
var err error // whichever happens first.
finished := make(chan struct{}) func WaitUntilContext(ctx context.Context, fn func(context.Context) error) error {
imout := make(chan struct{}) if ctx == nil {
ctx = context.Background()
}
finished := make(chan error, 1)
go func() { go func() {
err = fn(imout) finished <- fn(ctx)
finished <- struct{}{}
}() }()
select { select {
case <-finished: case err := <-finished:
return err return err
case <-time.After(tm): case <-ctx.Done():
close(imout) return ctx.Err()
return ERR_TIMEOUT
} }
} }
func WaitUntilFinished(fn func() error) <-chan error { // WaitUntilContextFinished is the asynchronous form of WaitUntilContext.
finished := make(chan error) func WaitUntilContextFinished(ctx context.Context, fn func(context.Context) error) <-chan error {
result := make(chan error, 1)
go func() { go func() {
err := fn() result <- WaitUntilContext(ctx, fn)
finished <- err close(result)
}()
return finished
}
func WaitUntilTimeoutFinished(tm time.Duration, fn func(chan struct{}) error) <-chan error {
var err error
finished := make(chan struct{})
result := make(chan error)
imout := make(chan struct{})
go func() {
err = fn(imout)
finished <- struct{}{}
}()
go func() {
select {
case <-finished:
result <- err
case <-time.After(tm):
close(imout)
result <- ERR_TIMEOUT
}
}() }()
return result return result
} }
// WaitUntilContextDone adapts a done-channel worker to a context-based wait.
func WaitUntilContextDone(ctx context.Context, fn func(<-chan struct{}) error) error {
if ctx == nil {
ctx = context.Background()
}
return WaitUntilContext(ctx, func(context.Context) error {
return fn(ctx.Done())
})
}
// WaitUntilContextDoneFinished is the asynchronous form of WaitUntilContextDone.
func WaitUntilContextDoneFinished(ctx context.Context, fn func(<-chan struct{}) error) <-chan error {
return WaitUntilContextFinished(ctx, func(ctx context.Context) error {
return fn(ctx.Done())
})
}
// WaitUntilTimeout is a legacy timeout helper kept for compatibility.
//
// The provided stop channel must be treated as receive-only by callers. New
// code should prefer WaitUntilContext or WaitUntilContextDone.
func WaitUntilTimeout(tm time.Duration, fn func(chan struct{}) error) error {
ctx, cancel := context.WithTimeout(context.Background(), tm)
defer cancel()
err := WaitUntilContextDone(ctx, func(done <-chan struct{}) error {
return fn(bridgeDoneChan(done))
})
if errors.Is(err, context.DeadlineExceeded) {
return ERR_TIMEOUT
}
return err
}
// WaitUntilFinished runs fn asynchronously and returns its eventual result.
func WaitUntilFinished(fn func() error) <-chan error {
return WaitUntilContextFinished(context.Background(), func(ctx context.Context) error {
return fn()
})
}
// WaitUntilTimeoutFinished is the asynchronous form of WaitUntilTimeout.
//
// The provided stop channel must be treated as receive-only by callers. New
// code should prefer WaitUntilContextFinished or WaitUntilContextDoneFinished.
func WaitUntilTimeoutFinished(tm time.Duration, fn func(chan struct{}) error) <-chan error {
result := make(chan error, 1)
go func() {
ctx, cancel := context.WithTimeout(context.Background(), tm)
defer cancel()
err := WaitUntilContextDone(ctx, func(done <-chan struct{}) error {
return fn(bridgeDoneChan(done))
})
if errors.Is(err, context.DeadlineExceeded) {
err = ERR_TIMEOUT
}
result <- err
close(result)
}()
return result
}
func bridgeDoneChan(done <-chan struct{}) chan struct{} {
stop := make(chan struct{})
if done == nil {
return stop
}
go func() {
<-done
close(stop)
}()
return stop
}

174
fn_test.go Normal file
View File

@ -0,0 +1,174 @@
package stario
import (
"context"
"errors"
"testing"
"time"
)
func TestWaitUntilContextReturnsWorkerError(t *testing.T) {
want := errors.New("worker failed")
err := WaitUntilContext(context.Background(), func(ctx context.Context) error {
return want
})
if !errors.Is(err, want) {
t.Fatalf("unexpected error: got %v want %v", err, want)
}
}
func TestWaitUntilContextReturnsDeadlineExceeded(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
workerReturned := make(chan struct{})
err := WaitUntilContext(ctx, func(ctx context.Context) error {
<-ctx.Done()
close(workerReturned)
return nil
})
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("unexpected context error: %v", err)
}
select {
case <-workerReturned:
case <-time.After(200 * time.Millisecond):
t.Fatal("worker did not return after context deadline")
}
}
func TestWaitUntilContextFinishedReturnsCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
result := WaitUntilContextFinished(ctx, func(ctx context.Context) error {
<-ctx.Done()
return nil
})
err, ok := <-result
if !ok {
t.Fatal("result channel closed without value")
}
if !errors.Is(err, context.Canceled) {
t.Fatalf("unexpected context error: %v", err)
}
if _, ok := <-result; ok {
t.Fatal("result channel should be closed after delivering the outcome")
}
}
func TestWaitUntilContextDoneBridgesDoneChannel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
doneClosed := make(chan struct{})
go func() {
time.Sleep(20 * time.Millisecond)
cancel()
}()
err := WaitUntilContextDone(ctx, func(done <-chan struct{}) error {
<-done
close(doneClosed)
return nil
})
if !errors.Is(err, context.Canceled) {
t.Fatalf("unexpected context error: %v", err)
}
select {
case <-doneClosed:
case <-time.After(200 * time.Millisecond):
t.Fatal("done channel was not closed when context canceled")
}
}
func TestWaitUntilContextDoneFinishedReturnsCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
doneClosed := make(chan struct{})
go func() {
time.Sleep(20 * time.Millisecond)
cancel()
}()
result := WaitUntilContextDoneFinished(ctx, func(done <-chan struct{}) error {
<-done
close(doneClosed)
return nil
})
err, ok := <-result
if !ok {
t.Fatal("result channel closed without value")
}
if !errors.Is(err, context.Canceled) {
t.Fatalf("unexpected context error: %v", err)
}
select {
case <-doneClosed:
case <-time.After(200 * time.Millisecond):
t.Fatal("done channel was not closed when context canceled")
}
if _, ok := <-result; ok {
t.Fatal("result channel should be closed after delivering the outcome")
}
}
func TestWaitUntilTimeoutReturnsWorkerError(t *testing.T) {
want := errors.New("worker failed")
err := WaitUntilTimeout(time.Second, func(stop chan struct{}) error {
return want
})
if !errors.Is(err, want) {
t.Fatalf("unexpected error: got %v want %v", err, want)
}
}
func TestWaitUntilTimeoutTimesOutWithoutBlockingWorkerReturn(t *testing.T) {
workerReturned := make(chan struct{})
err := WaitUntilTimeout(20*time.Millisecond, func(stop chan struct{}) error {
<-stop
close(workerReturned)
return nil
})
if !errors.Is(err, ERR_TIMEOUT) {
t.Fatalf("unexpected timeout error: %v", err)
}
select {
case <-workerReturned:
case <-time.After(200 * time.Millisecond):
t.Fatal("worker did not return after timeout signal")
}
}
func TestWaitUntilFinishedReturnsWorkerError(t *testing.T) {
want := errors.New("worker failed")
err := <-WaitUntilFinished(func() error {
return want
})
if !errors.Is(err, want) {
t.Fatalf("unexpected error: got %v want %v", err, want)
}
}
func TestWaitUntilTimeoutFinishedReturnsTimeout(t *testing.T) {
result := WaitUntilTimeoutFinished(20*time.Millisecond, func(stop chan struct{}) error {
<-stop
return nil
})
err, ok := <-result
if !ok {
t.Fatal("result channel closed without value")
}
if !errors.Is(err, ERR_TIMEOUT) {
t.Fatalf("unexpected timeout error: %v", err)
}
if _, ok := <-result; ok {
t.Fatal("result channel should be closed after delivering the outcome")
}
}
func TestWaitUntilTimeoutFinishedReturnsWorkerError(t *testing.T) {
want := errors.New("worker failed")
err, ok := <-WaitUntilTimeoutFinished(time.Second, func(stop chan struct{}) error {
return want
})
if !ok {
t.Fatal("result channel closed without value")
}
if !errors.Is(err, want) {
t.Fatalf("unexpected error: got %v want %v", err, want)
}
}

191
frameio.go Normal file
View File

@ -0,0 +1,191 @@
package stario
import (
"errors"
"io"
)
// DefaultFrameReaderBufferSize is the default transport read chunk size used by
// FrameReader.
const DefaultFrameReaderBufferSize = 32 * 1024
type frameReaderConnKey struct{}
// FrameWriter adapts StarQueue framing helpers to an io.Writer.
type FrameWriter struct {
writer io.Writer
queue *StarQueue
}
// NewFrameWriter creates a framing writer backed by queue. When queue is nil, a
// default StarQueue is created.
func NewFrameWriter(writer io.Writer, queue *StarQueue) *FrameWriter {
if queue == nil {
queue = NewQueue()
}
return &FrameWriter{
writer: writer,
queue: queue,
}
}
// WriteFrame writes one framed payload.
func (writer *FrameWriter) WriteFrame(payload []byte) error {
if writer == nil || writer.writer == nil || writer.queue == nil {
return io.ErrClosedPipe
}
return writer.queue.WriteFrame(writer.writer, payload)
}
// WriteFrameBuffers writes one framed payload using net.Buffers when possible.
func (writer *FrameWriter) WriteFrameBuffers(payload []byte) error {
if writer == nil || writer.writer == nil || writer.queue == nil {
return io.ErrClosedPipe
}
return writer.queue.WriteFrameBuffers(writer.writer, payload)
}
// WriteFramesBuffers writes multiple framed payloads in one batch when
// possible.
func (writer *FrameWriter) WriteFramesBuffers(payloads ...[]byte) error {
if writer == nil || writer.writer == nil || writer.queue == nil {
return io.ErrClosedPipe
}
return writer.queue.WriteFramesBuffers(writer.writer, payloads...)
}
// FrameReader adapts StarQueue parsing helpers to an io.Reader.
type FrameReader struct {
reader io.Reader
queue *StarQueue
connKey interface{}
readSize int
pending [][]byte
pendingErr error
}
// NewFrameReader creates a framing reader backed by queue. When queue is nil, a
// default StarQueue is created.
func NewFrameReader(reader io.Reader, queue *StarQueue) *FrameReader {
if queue == nil {
queue = NewQueue()
}
return &FrameReader{
reader: reader,
queue: queue,
connKey: &frameReaderConnKey{},
readSize: DefaultFrameReaderBufferSize,
}
}
// SetReadBufferSize updates the underlying transport read chunk size.
func (reader *FrameReader) SetReadBufferSize(size int) {
if reader == nil || size <= 0 {
return
}
reader.readSize = size
}
// SetConnKey overrides the internal queue connection key.
func (reader *FrameReader) SetConnKey(conn interface{}) error {
if reader == nil {
return io.ErrClosedPipe
}
if conn == nil {
reader.connKey = &frameReaderConnKey{}
return nil
}
if err := validateConnKey(conn); err != nil {
return err
}
reader.connKey = conn
return nil
}
// Next returns the next framed payload.
func (reader *FrameReader) Next() ([]byte, error) {
if reader == nil || reader.reader == nil || reader.queue == nil {
return nil, io.ErrClosedPipe
}
if len(reader.pending) > 0 {
return reader.popPending(), nil
}
if reader.pendingErr != nil {
err := reader.pendingErr
reader.pendingErr = nil
return nil, err
}
if reader.readSize <= 0 {
reader.readSize = DefaultFrameReaderBufferSize
}
buf := make([]byte, reader.readSize)
for {
n, readErr := reader.reader.Read(buf)
if n > 0 {
parseErr := reader.queue.ParseMessageOwned(buf[:n], reader.connKey, func(msg MsgQueue) error {
reader.pending = append(reader.pending, msg.Msg)
return nil
})
err := reader.normalizeReadErr(parseErr, readErr)
if len(reader.pending) > 0 {
reader.pendingErr = err
return reader.popPending(), nil
}
if err != nil {
return nil, err
}
}
if readErr != nil {
return nil, reader.normalizeReadErr(nil, readErr)
}
}
}
func (reader *FrameReader) popPending() []byte {
next := reader.pending[0]
reader.pending = reader.pending[1:]
return next
}
func joinFrameReaderError(parseErr error, readErr error) error {
switch {
case parseErr == nil:
return readErr
case readErr == nil:
return parseErr
default:
return errors.Join(parseErr, readErr)
}
}
func (reader *FrameReader) normalizeReadErr(parseErr error, readErr error) error {
if errors.Is(readErr, io.EOF) && reader.hasBufferedState() {
reader.clearBufferedState()
readErr = io.ErrUnexpectedEOF
}
return joinFrameReaderError(parseErr, readErr)
}
func (reader *FrameReader) hasBufferedState() bool {
if reader == nil || reader.queue == nil {
return false
}
stateAny, ok := reader.queue.states.Load(reader.connKey)
if !ok {
return false
}
state, ok := stateAny.(*queConnState)
if !ok {
return false
}
state.mu.Lock()
defer state.mu.Unlock()
return len(state.buf) > 0
}
func (reader *FrameReader) clearBufferedState() {
if reader == nil || reader.queue == nil {
return
}
reader.queue.states.Delete(reader.connKey)
}

105
frameio_test.go Normal file
View File

@ -0,0 +1,105 @@
package stario
import (
"bytes"
"errors"
"io"
"testing"
)
type chunkedReader struct {
data []byte
max int
}
func (reader *chunkedReader) Read(p []byte) (int, error) {
if len(reader.data) == 0 {
return 0, io.EOF
}
if reader.max > 0 && len(p) > reader.max {
p = p[:reader.max]
}
n := copy(p, reader.data)
reader.data = reader.data[n:]
return n, nil
}
func TestFrameWriterReaderRoundTrip(t *testing.T) {
var wire bytes.Buffer
writer := NewFrameWriter(&wire, nil)
if err := writer.WriteFrame([]byte("one")); err != nil {
t.Fatalf("WriteFrame failed: %v", err)
}
if err := writer.WriteFramesBuffers([][]byte{[]byte("two"), []byte("three")}...); err != nil {
t.Fatalf("WriteFramesBuffers failed: %v", err)
}
reader := NewFrameReader(bytes.NewReader(wire.Bytes()), nil)
cases := []string{"one", "two", "three"}
for _, want := range cases {
got, err := reader.Next()
if err != nil {
t.Fatalf("Next returned error: %v", err)
}
if string(got) != want {
t.Fatalf("unexpected frame: got %q want %q", got, want)
}
}
if _, err := reader.Next(); !errors.Is(err, io.EOF) {
t.Fatalf("expected EOF after draining frames, got %v", err)
}
}
func TestFrameReaderHandlesPartialTransportReads(t *testing.T) {
var wire bytes.Buffer
writer := NewFrameWriter(&wire, nil)
if err := writer.WriteFrame([]byte("hello")); err != nil {
t.Fatalf("WriteFrame failed: %v", err)
}
reader := NewFrameReader(&chunkedReader{data: wire.Bytes(), max: 3}, nil)
reader.SetReadBufferSize(3)
got, err := reader.Next()
if err != nil {
t.Fatalf("Next returned error: %v", err)
}
if string(got) != "hello" {
t.Fatalf("unexpected frame: %q", got)
}
}
func TestFrameWriterNilWriterFails(t *testing.T) {
writer := NewFrameWriter(nil, nil)
if err := writer.WriteFrame([]byte("hello")); !errors.Is(err, io.ErrClosedPipe) {
t.Fatalf("expected io.ErrClosedPipe, got %v", err)
}
}
func TestFrameReaderTruncatedFrameReturnsUnexpectedEOF(t *testing.T) {
queue := NewQueue()
frame := queue.BuildMessage([]byte("hello"))
reader := NewFrameReader(bytes.NewReader(frame[:len(frame)-1]), nil)
if _, err := reader.Next(); !errors.Is(err, io.ErrUnexpectedEOF) {
t.Fatalf("expected io.ErrUnexpectedEOF, got %v", err)
}
}
func TestFrameReaderReturnsUnexpectedEOFAfterPendingFrames(t *testing.T) {
queue := NewQueue()
first := queue.BuildMessage([]byte("one"))
second := queue.BuildMessage([]byte("two"))
wire := append(append([]byte{}, first...), second[:len(second)-1]...)
reader := NewFrameReader(bytes.NewReader(wire), nil)
got, err := reader.Next()
if err != nil {
t.Fatalf("unexpected error on first frame: %v", err)
}
if string(got) != "one" {
t.Fatalf("unexpected first frame: %q", got)
}
if _, err := reader.Next(); !errors.Is(err, io.ErrUnexpectedEOF) {
t.Fatalf("expected io.ErrUnexpectedEOF after draining pending frame, got %v", err)
}
}

6
go.mod
View File

@ -1,5 +1,7 @@
module b612.me/stario module b612.me/stario
go 1.16 go 1.20
require golang.org/x/crypto v0.26.0 require golang.org/x/term v0.23.0
require golang.org/x/sys v0.23.0 // indirect

63
go.sum
View File

@ -1,67 +1,4 @@
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw=
golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM= golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM=
golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU= golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU=
golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk= golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

235
io.go
View File

@ -3,11 +3,12 @@ package stario
import ( import (
"bufio" "bufio"
"fmt" "fmt"
"golang.org/x/crypto/ssh/terminal" "golang.org/x/term"
"io"
"os" "os"
"runtime"
"strconv" "strconv"
"strings" "strings"
"sync"
) )
type InputMsg struct { type InputMsg struct {
@ -16,30 +17,54 @@ type InputMsg struct {
skipSliceSigErr bool skipSliceSigErr bool
} }
type rawInputSignalMode uint8
const (
rawInputSignalIgnore rawInputSignalMode = iota
rawInputSignalExit
rawInputSignalReturnError
)
type rawTerminalSession struct { type rawTerminalSession struct {
fd int fd int
state *terminal.State state *term.State
reader *bufio.Reader reader *bufio.Reader
input io.Closer
redrawHint string redrawHint string
printNewline bool printNewline bool
mu sync.Mutex
} }
var rawTerminalSessionFactory = newRawTerminalSession
var inputSignalHandler = signal
// Passwd reads one password-style line in raw mode.
//
// When the user presses an input signal such as Ctrl+C, this compatibility
// entry exits the current flow and returns an empty message with a nil error.
func Passwd(hint string, defaultVal string) InputMsg { func Passwd(hint string, defaultVal string) InputMsg {
return passwd(hint, defaultVal, "", true) return passwd(hint, defaultVal, "", rawInputSignalExit)
} }
// PasswdWithMask is like Passwd but echoes the provided mask string.
func PasswdWithMask(hint string, defaultVal string, mask string) InputMsg { func PasswdWithMask(hint string, defaultVal string, mask string) InputMsg {
return passwd(hint, defaultVal, mask, true) return passwd(hint, defaultVal, mask, rawInputSignalExit)
} }
// PasswdResponseSignal is like Passwd but preserves input-signal errors for
// callers that need to distinguish Ctrl+C / Ctrl+Z style exits.
func PasswdResponseSignal(hint string, defaultVal string) InputMsg { func PasswdResponseSignal(hint string, defaultVal string) InputMsg {
return passwd(hint, defaultVal, "", true) return passwd(hint, defaultVal, "", rawInputSignalReturnError)
} }
// PasswdResponseSignalWithMask is like PasswdResponseSignal but echoes the
// provided mask string.
func PasswdResponseSignalWithMask(hint string, defaultVal string, mask string) InputMsg { func PasswdResponseSignalWithMask(hint string, defaultVal string, mask string) InputMsg {
return passwd(hint, defaultVal, mask, true) return passwd(hint, defaultVal, mask, rawInputSignalReturnError)
} }
// MessageBoxRaw reads one line in raw mode without treating control keys as
// exit signals.
func MessageBoxRaw(hint string, defaultVal string) InputMsg { func MessageBoxRaw(hint string, defaultVal string) InputMsg {
return messageBox(hint, defaultVal) return messageBox(hint, defaultVal)
} }
@ -48,35 +73,64 @@ func newRawTerminalSession(hint string, printNewline bool) (*rawTerminalSession,
if hint != "" { if hint != "" {
fmt.Print(hint) fmt.Print(hint)
} }
fd := int(os.Stdin.Fd()) input, err := openRawTerminalInput()
state, err := terminal.MakeRaw(fd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
fd := int(input.Fd())
state, err := term.MakeRaw(fd)
if err != nil {
_ = input.Close()
return nil, err
}
return &rawTerminalSession{ return &rawTerminalSession{
fd: fd, fd: fd,
state: state, state: state,
reader: bufio.NewReader(os.Stdin), reader: bufio.NewReader(input),
input: input,
redrawHint: promptRedrawHint(hint), redrawHint: promptRedrawHint(hint),
printNewline: printNewline, printNewline: printNewline,
}, nil }, nil
} }
func (session *rawTerminalSession) Close() { func (session *rawTerminalSession) Close() {
if session == nil || session.state == nil { if session == nil {
return return
} }
_ = terminal.Restore(session.fd, session.state) session.mu.Lock()
defer session.mu.Unlock()
if session.state != nil {
_ = term.Restore(session.fd, session.state)
session.state = nil
}
if session.printNewline { if session.printNewline {
fmt.Println() fmt.Println()
session.printNewline = false
}
if session.input != nil {
_ = session.input.Close()
session.input = nil
} }
} }
func (session *rawTerminalSession) Restore() error { func (session *rawTerminalSession) Restore() error {
if session == nil || session.state == nil { if session == nil {
return nil return nil
} }
return terminal.Restore(session.fd, session.state) session.mu.Lock()
defer session.mu.Unlock()
if session.state == nil {
return nil
}
if err := term.Restore(session.fd, session.state); err != nil {
return err
}
session.state = nil
return nil
}
func (session *rawTerminalSession) Abort() {
session.Close()
} }
func promptRedrawHint(hint string) string { func promptRedrawHint(hint string) string {
@ -101,6 +155,20 @@ func renderRawEcho(ioBuf []rune, mask string) string {
return strings.Repeat(mask, len(ioBuf)) return strings.Repeat(mask, len(ioBuf))
} }
func rawEchoRenderUnit(r rune, mask string, maskWidth int) (string, int, bool) {
if mask != "" {
if maskWidth <= 0 {
return "", 0, false
}
return mask, maskWidth, true
}
width := runeDisplayWidth(r)
if width <= 0 {
return "", 0, false
}
return string(r), width, true
}
func redrawPromptLine(hint string, echo string, lastWidth int) int { func redrawPromptLine(hint string, echo string, lastWidth int) int {
nowWidth := stringDisplayWidth(hint) + stringDisplayWidth(echo) nowWidth := stringDisplayWidth(hint) + stringDisplayWidth(echo)
clearWidth := lastWidth clearWidth := lastWidth
@ -121,6 +189,22 @@ func redrawPromptLine(hint string, echo string, lastWidth int) int {
return nowWidth return nowWidth
} }
func erasePromptTail(width int) {
if width <= 0 {
return
}
backtrack := strings.Repeat("\b", width)
fmt.Print(backtrack)
fmt.Print(strings.Repeat(" ", width))
fmt.Print(backtrack)
}
func redrawPromptEcho(hint string, ioBuf []rune, mask string, lastWidth int) (int, int) {
echo := renderRawEcho(ioBuf, mask)
echoWidth := stringDisplayWidth(echo)
return echoWidth, redrawPromptLine(hint, echo, lastWidth)
}
func stringDisplayWidth(text string) int { func stringDisplayWidth(text string) int {
width := 0 width := 0
for _, r := range text { for _, r := range text {
@ -157,46 +241,79 @@ func isWideRune(r rune) bool {
(r >= 0x20000 && r <= 0x3fffd)) (r >= 0x20000 && r <= 0x3fffd))
} }
func rawLineInput(hint string, defaultVal string, mask string, handleSignal bool) InputMsg { func signalInputResult(mode rawInputSignalMode, err error) InputMsg {
session, err := newRawTerminalSession(hint, true) switch mode {
case rawInputSignalExit:
return InputMsg{msg: "", err: nil}
case rawInputSignalReturnError:
return InputMsg{msg: "", err: err}
default:
return InputMsg{msg: "", err: nil}
}
}
func rawLineInput(hint string, defaultVal string, mask string, signalMode rawInputSignalMode) InputMsg {
session, err := rawTerminalSessionFactory(hint, true)
if err != nil { if err != nil {
return InputMsg{msg: "", err: err} return InputMsg{msg: "", err: err}
} }
defer session.Close() defer session.Close()
return rawLineInputSession(session, defaultVal, mask, signalMode)
}
func rawLineInputSession(session *rawTerminalSession, defaultVal string, mask string, signalMode rawInputSignalMode) InputMsg {
if session == nil || session.reader == nil {
return InputMsg{msg: "", err: io.ErrClosedPipe}
}
ioBuf := make([]rune, 0, 16) ioBuf := make([]rune, 0, 16)
lastWidth := 0 promptWidth := stringDisplayWidth(session.redrawHint)
maskWidth := stringDisplayWidth(mask)
echoWidth := 0
lastWidth := promptWidth
for { for {
b, _, err := session.reader.ReadRune() b, _, err := session.reader.ReadRune()
if err != nil { if err != nil {
return InputMsg{msg: "", err: err} return InputMsg{msg: "", err: err}
} }
if handleSignal && isSignal(b) { if signalMode != rawInputSignalIgnore && isSignal(b) {
if runtime.GOOS != "windows" { session.Close()
if err := session.Restore(); err != nil { if signalMode == rawInputSignalExit {
return InputMsg{msg: "", err: err} return signalInputResult(signalMode, nil)
} }
} return signalInputResult(signalMode, inputSignalHandler(b))
if err := signal(b); err != nil {
return InputMsg{msg: "", err: err}
}
continue
} }
switch b { switch b {
case 0x0d, 0x0a: case 0x0d, 0x0a:
return InputMsg{msg: finalizeInputValue(string(ioBuf), defaultVal), err: nil} return InputMsg{msg: finalizeInputValue(string(ioBuf), defaultVal), err: nil}
case 0x08, 0x7F: case 0x08, 0x7F:
if len(ioBuf) > 0 { if len(ioBuf) > 0 {
removed := ioBuf[len(ioBuf)-1]
ioBuf = ioBuf[:len(ioBuf)-1] ioBuf = ioBuf[:len(ioBuf)-1]
if _, removedWidth, ok := rawEchoRenderUnit(removed, mask, maskWidth); ok {
erasePromptTail(removedWidth)
echoWidth -= removedWidth
if echoWidth < 0 {
echoWidth = 0
}
lastWidth = promptWidth + echoWidth
continue
}
} }
default: default:
ioBuf = append(ioBuf, b) ioBuf = append(ioBuf, b)
if appendText, appendWidth, ok := rawEchoRenderUnit(b, mask, maskWidth); ok {
fmt.Print(appendText)
echoWidth += appendWidth
lastWidth = promptWidth + echoWidth
continue
} }
lastWidth = redrawPromptLine(session.redrawHint, renderRawEcho(ioBuf, mask), lastWidth) }
echoWidth, lastWidth = redrawPromptEcho(session.redrawHint, ioBuf, mask, lastWidth)
} }
} }
func messageBox(hint string, defaultVal string) InputMsg { func messageBox(hint string, defaultVal string) InputMsg {
return rawLineInput(hint, defaultVal, "", false) return rawLineInput(hint, defaultVal, "", rawInputSignalIgnore)
} }
func isSignal(s rune) bool { func isSignal(s rune) bool {
@ -208,10 +325,12 @@ func isSignal(s rune) bool {
} }
} }
func passwd(hint string, defaultVal string, mask string, handleSignal bool) InputMsg { func passwd(hint string, defaultVal string, mask string, signalMode rawInputSignalMode) InputMsg {
return rawLineInput(hint, defaultVal, mask, handleSignal) return rawLineInput(hint, defaultVal, mask, signalMode)
} }
// MessageBox reads one line in cooked mode and falls back to defaultVal when
// the trimmed input is empty.
func MessageBox(hint string, defaultVal string) InputMsg { func MessageBox(hint string, defaultVal string) InputMsg {
if hint != "" { if hint != "" {
fmt.Print(hint) fmt.Print(hint)
@ -264,7 +383,7 @@ func (im InputMsg) sliceFn(sep string, fn func(string) (interface{}, error)) ([]
return res, err return res, err
} }
for _, v := range data { for _, v := range data {
code, err := fn(v) code, err := fn(strings.TrimSpace(v))
if err != nil && !im.skipSliceSigErr { if err != nil && !im.skipSliceSigErr {
return nil, err return nil, err
} else if err == nil { } else if err == nil {
@ -428,7 +547,8 @@ func (im InputMsg) MustFloat32() float32 {
func (im InputMsg) SliceFloat32(sep string) ([]float32, error) { func (im InputMsg) SliceFloat32(sep string) ([]float32, error) {
data, err := im.sliceFn(sep, func(v string) (interface{}, error) { data, err := im.sliceFn(sep, func(v string) (interface{}, error) {
return strconv.ParseFloat(v, 32) f, err := strconv.ParseFloat(v, 32)
return float32(f), err
}) })
var res []float32 var res []float32
for _, v := range data { for _, v := range data {
@ -477,13 +597,47 @@ func YesNoE(hint string, defaults bool) (bool, error) {
} }
} }
func buildTriggerPrefixTable(triggerRunes []rune) []int {
if len(triggerRunes) == 0 {
return nil
}
prefix := make([]int, len(triggerRunes))
for i := 1; i < len(triggerRunes); i++ {
j := prefix[i-1]
for j > 0 && triggerRunes[i] != triggerRunes[j] {
j = prefix[j-1]
}
if triggerRunes[i] == triggerRunes[j] {
j++
}
prefix[i] = j
}
return prefix
}
func advanceTriggerIndex(triggerRunes []rune, prefix []int, current int, input rune) (int, bool) {
if len(triggerRunes) == 0 {
return 0, true
}
for current > 0 && input != triggerRunes[current] {
current = prefix[current-1]
}
if input == triggerRunes[current] {
current++
if current == len(triggerRunes) {
return current, true
}
}
return current, false
}
// StopUntil keeps reading raw input until trigger is matched.
// When trigger == "", it returns after the first key press, which is used for
// "press any key to continue" style prompts.
func StopUntil(hint string, trigger string, repeat bool) error { func StopUntil(hint string, trigger string, repeat bool) error {
triggerRunes := []rune(trigger) triggerRunes := []rune(trigger)
pressLen := len(triggerRunes) prefix := buildTriggerPrefixTable(triggerRunes)
if trigger == "" { session, err := rawTerminalSessionFactory(hint, false)
pressLen = 1
}
session, err := newRawTerminalSession(hint, false)
if err != nil { if err != nil {
return err return err
} }
@ -497,11 +651,12 @@ func StopUntil(hint string, trigger string, repeat bool) error {
if trigger == "" { if trigger == "" {
break break
} }
if b == triggerRunes[i] { next, complete := advanceTriggerIndex(triggerRunes, prefix, i, b)
i++ if complete {
if i == pressLen {
break break
} }
if next > 0 {
i = next
continue continue
} }
i = 0 i = 0

81
io_benchmark_test.go Normal file
View File

@ -0,0 +1,81 @@
package stario
import (
"strings"
"testing"
)
func benchmarkRawRedrawAppend(b *testing.B, current []rune, next rune, mask string) {
b.Run("legacy", func(b *testing.B) {
ioBuf := make([]rune, len(current)+1)
copy(ioBuf, current)
ioBuf[len(current)] = next
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
echo := renderRawEcho(ioBuf, mask)
_ = stringDisplayWidth(echo)
}
})
b.Run("fast", func(b *testing.B) {
maskWidth := stringDisplayWidth(mask)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _, ok := rawEchoRenderUnit(next, mask, maskWidth)
if !ok {
b.Fatal("fast path rejected append rune")
}
}
})
}
func benchmarkRawRedrawBackspace(b *testing.B, current []rune, mask string) {
if len(current) == 0 {
b.Fatal("backspace benchmark requires non-empty input")
}
last := current[len(current)-1]
trimmed := current[:len(current)-1]
b.Run("legacy", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
echo := renderRawEcho(trimmed, mask)
_ = stringDisplayWidth(echo)
}
})
b.Run("fast", func(b *testing.B) {
maskWidth := stringDisplayWidth(mask)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _, ok := rawEchoRenderUnit(last, mask, maskWidth)
if !ok {
b.Fatal("fast path rejected backspace rune")
}
}
})
}
func BenchmarkRawRedrawAppendPlainLong(b *testing.B) {
current := []rune(strings.Repeat("a", 4096))
benchmarkRawRedrawAppend(b, current, 'b', "")
}
func BenchmarkRawRedrawAppendMaskLong(b *testing.B) {
current := []rune(strings.Repeat("a", 4096))
benchmarkRawRedrawAppend(b, current, 'b', "[]")
}
func BenchmarkRawRedrawBackspacePlainLong(b *testing.B) {
current := []rune(strings.Repeat("a", 4096))
benchmarkRawRedrawBackspace(b, current, "")
}
func BenchmarkRawRedrawBackspaceMaskLong(b *testing.B) {
current := []rune(strings.Repeat("a", 4096))
benchmarkRawRedrawBackspace(b, current, "[]")
}

177
io_context.go Normal file
View File

@ -0,0 +1,177 @@
package stario
import (
"context"
"errors"
"io"
)
const defaultCopyContextBufferSize = 32 * 1024
type readContextResult struct {
data []byte
err error
}
type writeContextResult struct {
n int
err error
}
// ReadFullContext reads exactly len(buf) bytes unless the context is canceled
// or the underlying reader returns an error.
//
// If ctx is canceled while the underlying Read call is blocked, ReadFullContext
// returns ctx.Err() without waiting for that call to finish. The underlying
// reader may still complete asynchronously afterwards.
func ReadFullContext(ctx context.Context, reader io.Reader, buf []byte) (int, error) {
if reader == nil {
return 0, io.ErrClosedPipe
}
if ctx == nil {
ctx = context.Background()
}
if err := ctx.Err(); err != nil {
return 0, err
}
total := 0
for total < len(buf) {
chunk, err := readContext(ctx, reader, len(buf)-total)
if len(chunk) > 0 {
total += copy(buf[total:], chunk)
}
if err != nil {
if errors.Is(err, io.EOF) {
if total > 0 {
return total, io.ErrUnexpectedEOF
}
return total, io.EOF
}
return total, err
}
if len(chunk) == 0 {
return total, io.ErrNoProgress
}
}
return total, nil
}
// WriteFullContext writes the full payload unless the context is canceled or
// the underlying writer returns an error.
//
// If ctx is canceled while the underlying Write call is blocked,
// WriteFullContext returns ctx.Err() without waiting for that call to finish.
// The underlying writer may still complete asynchronously afterwards.
func WriteFullContext(ctx context.Context, writer io.Writer, data []byte) (int, error) {
if writer == nil {
return 0, io.ErrClosedPipe
}
if ctx == nil {
ctx = context.Background()
}
if err := ctx.Err(); err != nil {
return 0, err
}
total := 0
for total < len(data) {
written, err := writeContext(ctx, writer, data[total:])
if written > 0 {
total += written
}
if err != nil {
return total, err
}
if written == 0 {
return total, io.ErrNoProgress
}
}
return total, nil
}
// CopyContext copies from src to dst until EOF, context cancellation, or a
// non-EOF error occurs.
//
// If ctx is canceled while the current read or write is blocked, CopyContext
// returns ctx.Err() without waiting for that operation to finish.
func CopyContext(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) {
if dst == nil || src == nil {
return 0, io.ErrClosedPipe
}
if ctx == nil {
ctx = context.Background()
}
if err := ctx.Err(); err != nil {
return 0, err
}
var copied int64
for {
chunk, readErr := readContext(ctx, src, defaultCopyContextBufferSize)
if len(chunk) > 0 {
written, writeErr := WriteFullContext(ctx, dst, chunk)
copied += int64(written)
if writeErr != nil {
return copied, writeErr
}
}
if readErr != nil {
if errors.Is(readErr, io.EOF) {
return copied, nil
}
return copied, readErr
}
if len(chunk) == 0 {
return copied, io.ErrNoProgress
}
}
}
func readContext(ctx context.Context, reader io.Reader, size int) ([]byte, error) {
if size <= 0 {
return nil, nil
}
resultCh := make(chan readContextResult, 1)
go func() {
tmp := make([]byte, size)
n, err := reader.Read(tmp)
if n > 0 {
tmp = tmp[:n]
} else {
tmp = nil
}
resultCh <- readContextResult{data: tmp, err: err}
}()
select {
case result := <-resultCh:
return result.data, result.err
case <-ctx.Done():
select {
case result := <-resultCh:
return result.data, result.err
default:
return nil, ctx.Err()
}
}
}
func writeContext(ctx context.Context, writer io.Writer, data []byte) (int, error) {
if len(data) == 0 {
return 0, nil
}
payload := append([]byte(nil), data...)
resultCh := make(chan writeContextResult, 1)
go func() {
n, err := writer.Write(payload)
resultCh <- writeContextResult{n: n, err: err}
}()
select {
case result := <-resultCh:
return result.n, result.err
case <-ctx.Done():
select {
case result := <-resultCh:
return result.n, result.err
default:
return 0, ctx.Err()
}
}
}

153
io_context_test.go Normal file
View File

@ -0,0 +1,153 @@
package stario
import (
"bytes"
"context"
"errors"
"io"
"sync"
"testing"
"time"
)
type chunkedWriter struct {
buf bytes.Buffer
max int
}
func (writer *chunkedWriter) Write(p []byte) (int, error) {
if writer.max <= 0 || len(p) <= writer.max {
return writer.buf.Write(p)
}
return writer.buf.Write(p[:writer.max])
}
type blockingReader struct {
started chan struct{}
release chan struct{}
data []byte
once sync.Once
}
func (reader *blockingReader) Read(p []byte) (int, error) {
reader.once.Do(func() { close(reader.started) })
<-reader.release
n := copy(p, reader.data)
return n, nil
}
type blockingWriter struct {
started chan struct{}
release chan struct{}
once sync.Once
}
func (writer *blockingWriter) Write(p []byte) (int, error) {
writer.once.Do(func() { close(writer.started) })
<-writer.release
return len(p), nil
}
func TestReadFullContext(t *testing.T) {
buf := make([]byte, 5)
n, err := ReadFullContext(context.Background(), bytes.NewBufferString("hello"), buf)
if err != nil {
t.Fatalf("ReadFullContext returned error: %v", err)
}
if n != 5 || string(buf) != "hello" {
t.Fatalf("unexpected payload: n=%d data=%q", n, buf)
}
}
func TestReadFullContextReturnsUnexpectedEOF(t *testing.T) {
buf := make([]byte, 5)
n, err := ReadFullContext(context.Background(), bytes.NewBufferString("hey"), buf)
if !errors.Is(err, io.ErrUnexpectedEOF) {
t.Fatalf("expected unexpected EOF, got %v", err)
}
if n != 3 || string(buf[:n]) != "hey" {
t.Fatalf("unexpected payload: n=%d data=%q", n, buf[:n])
}
}
func TestReadFullContextCanceledWhileBlocked(t *testing.T) {
reader := &blockingReader{
started: make(chan struct{}),
release: make(chan struct{}),
data: []byte("hello"),
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
done := make(chan error, 1)
go func() {
buf := make([]byte, 5)
_, err := ReadFullContext(ctx, reader, buf)
done <- err
}()
<-reader.started
cancel()
select {
case err := <-done:
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected context canceled, got %v", err)
}
case <-time.After(200 * time.Millisecond):
t.Fatal("ReadFullContext did not return after cancel")
}
close(reader.release)
}
func TestWriteFullContext(t *testing.T) {
writer := &chunkedWriter{max: 2}
n, err := WriteFullContext(context.Background(), writer, []byte("hello"))
if err != nil {
t.Fatalf("WriteFullContext returned error: %v", err)
}
if n != 5 || writer.buf.String() != "hello" {
t.Fatalf("unexpected write result: n=%d data=%q", n, writer.buf.String())
}
}
func TestWriteFullContextCanceledWhileBlocked(t *testing.T) {
writer := &blockingWriter{
started: make(chan struct{}),
release: make(chan struct{}),
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
done := make(chan error, 1)
go func() {
_, err := WriteFullContext(ctx, writer, []byte("hello"))
done <- err
}()
<-writer.started
cancel()
select {
case err := <-done:
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected context canceled, got %v", err)
}
case <-time.After(200 * time.Millisecond):
t.Fatal("WriteFullContext did not return after cancel")
}
close(writer.release)
}
func TestCopyContext(t *testing.T) {
var dst bytes.Buffer
written, err := CopyContext(context.Background(), &dst, bytes.NewBufferString("hello world"))
if err != nil {
t.Fatalf("CopyContext returned error: %v", err)
}
if written != int64(len("hello world")) || dst.String() != "hello world" {
t.Fatalf("unexpected copy result: written=%d data=%q", written, dst.String())
}
}

View File

@ -1,10 +1,32 @@
package stario package stario
import ( import (
"bufio"
"errors"
"fmt" "fmt"
"strings"
"testing" "testing"
) )
func installRawInputStub(t *testing.T, input string, signalErr error) {
t.Helper()
prevFactory := rawTerminalSessionFactory
prevSignalHandler := inputSignalHandler
rawTerminalSessionFactory = func(hint string, printNewline bool) (*rawTerminalSession, error) {
return &rawTerminalSession{
reader: bufio.NewReader(strings.NewReader(input)),
redrawHint: promptRedrawHint(hint),
}, nil
}
inputSignalHandler = func(sigtype rune) error {
return signalErr
}
t.Cleanup(func() {
rawTerminalSessionFactory = prevFactory
inputSignalHandler = prevSignalHandler
})
}
func TestPromptRedrawHint(t *testing.T) { func TestPromptRedrawHint(t *testing.T) {
got := promptRedrawHint("头部提示\n 中文确认: ") got := promptRedrawHint("头部提示\n 中文确认: ")
if got != "中文确认:" { if got != "中文确认:" {
@ -19,6 +41,81 @@ func TestStringDisplayWidth(t *testing.T) {
} }
} }
func TestRawEchoRenderUnitPlainRune(t *testing.T) {
text, width, ok := rawEchoRenderUnit('中', "", 0)
if !ok {
t.Fatal("expected plain wide rune to use fast path")
}
if text != "中" || width != 2 {
t.Fatalf("unexpected render unit: got (%q, %d) want (%q, %d)", text, width, "中", 2)
}
}
func TestRawEchoRenderUnitMaskedRune(t *testing.T) {
text, width, ok := rawEchoRenderUnit('a', "[]", stringDisplayWidth("[]"))
if !ok {
t.Fatal("expected masked rune to use fast path")
}
if text != "[]" || width != 2 {
t.Fatalf("unexpected render unit: got (%q, %d) want (%q, %d)", text, width, "[]", 2)
}
}
func TestRawEchoRenderUnitFallsBackForControlRune(t *testing.T) {
if _, _, ok := rawEchoRenderUnit('\x00', "", 0); ok {
t.Fatal("expected control rune to fall back to full redraw")
}
}
func TestSignalInputResultExitSuppressesError(t *testing.T) {
got := signalInputResult(rawInputSignalExit, ErrSignalInterrupt)
if got.err != nil {
t.Fatalf("expected nil error, got %v", got.err)
}
if got.msg != "" {
t.Fatalf("expected empty message, got %q", got.msg)
}
}
func TestSignalInputResultReturnErrorPreservesSignal(t *testing.T) {
got := signalInputResult(rawInputSignalReturnError, ErrSignalInterrupt)
if !errors.Is(got.err, ErrSignalInterrupt) {
t.Fatalf("expected signal error, got %v", got.err)
}
if got.msg != "" {
t.Fatalf("expected empty message, got %q", got.msg)
}
}
func TestPasswdSuppressesSignalError(t *testing.T) {
installRawInputStub(t, string([]rune{0x03}), ErrSignalInterrupt)
got := Passwd("", "fallback")
if got.err != nil {
t.Fatalf("expected nil error, got %v", got.err)
}
if got.msg != "" {
t.Fatalf("expected empty message after signal exit, got %q", got.msg)
}
}
func TestPasswdResponseSignalPreservesSignalError(t *testing.T) {
installRawInputStub(t, string([]rune{0x03}), ErrSignalInterrupt)
got := PasswdResponseSignal("", "fallback")
if !errors.Is(got.err, ErrSignalInterrupt) {
t.Fatalf("expected interrupt error, got %v", got.err)
}
if got.msg != "" {
t.Fatalf("expected empty message after signal exit, got %q", got.msg)
}
}
func TestStopUntilEmptyTriggerReturnsAfterFirstKey(t *testing.T) {
installRawInputStub(t, "abc", nil)
if err := StopUntil("", "", false); err != nil {
t.Fatalf("StopUntil returned error: %v", err)
}
}
func TestParseYesNoValue(t *testing.T) { func TestParseYesNoValue(t *testing.T) {
cases := []struct { cases := []struct {
name string name string
@ -40,6 +137,59 @@ func TestParseYesNoValue(t *testing.T) {
} }
} }
func TestSliceFloat32(t *testing.T) {
data := InputMsg{msg: "1.5,2.25", err: nil}
got, err := data.SliceFloat32(",")
if err != nil {
t.Fatalf("SliceFloat32 returned error: %v", err)
}
if len(got) != 2 || got[0] != float32(1.5) || got[1] != float32(2.25) {
t.Fatalf("unexpected float32 slice: %#v", got)
}
}
func TestTypedSliceParsingTrimsTokenWhitespace(t *testing.T) {
ints, err := (InputMsg{msg: "1, 2, 3"}).SliceInt(",")
if err != nil {
t.Fatalf("SliceInt returned error: %v", err)
}
if len(ints) != 3 || ints[0] != 1 || ints[1] != 2 || ints[2] != 3 {
t.Fatalf("unexpected int slice: %#v", ints)
}
bools, err := (InputMsg{msg: "true, false, true"}).SliceBool(",")
if err != nil {
t.Fatalf("SliceBool returned error: %v", err)
}
if len(bools) != 3 || !bools[0] || bools[1] || !bools[2] {
t.Fatalf("unexpected bool slice: %#v", bools)
}
float64s, err := (InputMsg{msg: "1.25, 2.5, 3.75"}).SliceFloat64(",")
if err != nil {
t.Fatalf("SliceFloat64 returned error: %v", err)
}
if len(float64s) != 3 || float64s[0] != 1.25 || float64s[1] != 2.5 || float64s[2] != 3.75 {
t.Fatalf("unexpected float64 slice: %#v", float64s)
}
}
func TestAdvanceTriggerIndexHandlesOverlap(t *testing.T) {
trigger := []rune("aba")
prefix := buildTriggerPrefixTable(trigger)
index := 0
complete := false
for _, r := range []rune("aaba") {
index, complete = advanceTriggerIndex(trigger, prefix, index, r)
if complete {
break
}
}
if !complete {
t.Fatal("expected overlapped trigger to complete")
}
}
func Test_Slice(t *testing.T) { func Test_Slice(t *testing.T) {
var data = InputMsg{ var data = InputMsg{
msg: "true,false,true,true,false,0,1,hello", msg: "true,false,true,true,false,0,1,hello",

71
pipe.go Normal file
View File

@ -0,0 +1,71 @@
package stario
import "io"
// StarPipeReader is the read side returned by NewStarPipe.
type StarPipeReader struct {
buf *StarBuffer
}
// StarPipeWriter is the write side returned by NewStarPipe.
type StarPipeWriter struct {
buf *StarBuffer
}
// NewStarPipe creates a buffered in-memory pipe backed by StarBuffer.
//
// The writer side uses Close to signal a graceful end-of-stream and Abort to
// fail both sides immediately.
func NewStarPipe(capacity uint64) (*StarPipeReader, *StarPipeWriter, error) {
buf, err := NewStarBuffer(capacity)
if err != nil {
return nil, nil, err
}
return &StarPipeReader{buf: buf}, &StarPipeWriter{buf: buf}, nil
}
// Read reads buffered bytes from the pipe.
func (reader *StarPipeReader) Read(p []byte) (int, error) {
if reader == nil || reader.buf == nil {
return 0, io.ErrClosedPipe
}
return reader.buf.Read(p)
}
// Close aborts the pipe from the read side and wakes blocked writers.
func (reader *StarPipeReader) Close() error {
if reader == nil || reader.buf == nil {
return io.ErrClosedPipe
}
return reader.buf.Abort()
}
// Abort aborts the pipe from the read side and wakes blocked writers.
func (reader *StarPipeReader) Abort() error {
return reader.Close()
}
// Write writes bytes into the pipe buffer.
func (writer *StarPipeWriter) Write(p []byte) (int, error) {
if writer == nil || writer.buf == nil {
return 0, io.ErrClosedPipe
}
return writer.buf.Write(p)
}
// Close gracefully closes the write side. Buffered bytes remain readable until
// drained.
func (writer *StarPipeWriter) Close() error {
if writer == nil || writer.buf == nil {
return io.ErrClosedPipe
}
return writer.buf.Close()
}
// Abort aborts the pipe immediately.
func (writer *StarPipeWriter) Abort() error {
if writer == nil || writer.buf == nil {
return io.ErrClosedPipe
}
return writer.buf.Abort()
}

53
pipe_test.go Normal file
View File

@ -0,0 +1,53 @@
package stario
import (
"bytes"
"errors"
"io"
"testing"
)
func TestStarPipeRoundTrip(t *testing.T) {
reader, writer, err := NewStarPipe(16)
if err != nil {
t.Fatal(err)
}
want := []byte("hello world")
go func() {
_, _ = writer.Write(want)
_ = writer.Close()
}()
got, err := io.ReadAll(reader)
if err != nil {
t.Fatalf("ReadAll failed: %v", err)
}
if !bytes.Equal(got, want) {
t.Fatalf("unexpected payload: got %q want %q", got, want)
}
}
func TestStarPipeReaderCloseAbortsWriter(t *testing.T) {
reader, writer, err := NewStarPipe(4)
if err != nil {
t.Fatal(err)
}
if err := reader.Close(); err != nil {
t.Fatal(err)
}
if _, err := writer.Write([]byte("x")); !errors.Is(err, ErrStarBufferClosed) {
t.Fatalf("unexpected writer error after reader close: %v", err)
}
}
func TestStarPipeNilEndsReportClosedPipe(t *testing.T) {
var reader *StarPipeReader
var writer *StarPipeWriter
if _, err := reader.Read(make([]byte, 1)); !errors.Is(err, io.ErrClosedPipe) {
t.Fatalf("unexpected nil reader error: %v", err)
}
if _, err := writer.Write([]byte("x")); !errors.Is(err, io.ErrClosedPipe) {
t.Fatalf("unexpected nil writer error: %v", err)
}
}

325
que.go
View File

@ -1,325 +0,0 @@
package stario
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"os"
"sync"
"time"
)
var ErrDeadlineExceeded error = errors.New("deadline exceeded")
// 识别头
var header = []byte{11, 27, 19, 96, 12, 25, 02, 20}
// MsgQueue 为基本的信息单位
type MsgQueue struct {
ID uint16
Msg []byte
Conn interface{}
}
// StarQueue 为流数据中的消息队列分发
type StarQueue struct {
maxLength uint32
count int64
Encode bool
msgID uint16
msgPool chan MsgQueue
unFinMsg sync.Map
lastID int //= -1
ctx context.Context
cancel context.CancelFunc
duration time.Duration
EncodeFunc func([]byte) []byte
DecodeFunc func([]byte) []byte
//restoreMu sync.Mutex
}
func NewQueueCtx(ctx context.Context, count int64, maxMsgLength uint32) *StarQueue {
var q StarQueue
q.Encode = false
q.count = count
q.maxLength = maxMsgLength
q.msgPool = make(chan MsgQueue, count)
if ctx == nil {
q.ctx, q.cancel = context.WithCancel(context.Background())
} else {
q.ctx, q.cancel = context.WithCancel(ctx)
}
q.duration = 0
return &q
}
func NewQueueWithCount(count int64) *StarQueue {
return NewQueueCtx(nil, count, 0)
}
// NewQueue 建立一个新消息队列
func NewQueue() *StarQueue {
return NewQueueWithCount(32)
}
// Uint32ToByte 4位uint32转byte
func Uint32ToByte(src uint32) []byte {
res := make([]byte, 4)
res[3] = uint8(src)
res[2] = uint8(src >> 8)
res[1] = uint8(src >> 16)
res[0] = uint8(src >> 24)
return res
}
// ByteToUint32 byte转4位uint32
func ByteToUint32(src []byte) uint32 {
var res uint32
buffer := bytes.NewBuffer(src)
binary.Read(buffer, binary.BigEndian, &res)
return res
}
// Uint16ToByte 2位uint16转byte
func Uint16ToByte(src uint16) []byte {
res := make([]byte, 2)
res[1] = uint8(src)
res[0] = uint8(src >> 8)
return res
}
// ByteToUint16 用于byte转uint16
func ByteToUint16(src []byte) uint16 {
var res uint16
buffer := bytes.NewBuffer(src)
binary.Read(buffer, binary.BigEndian, &res)
return res
}
// BuildMessage 生成编码后的信息用于发送
func (q *StarQueue) BuildMessage(src []byte) []byte {
var buff bytes.Buffer
q.msgID++
if q.Encode {
src = q.EncodeFunc(src)
}
length := uint32(len(src))
buff.Write(header)
buff.Write(Uint32ToByte(length))
buff.Write(Uint16ToByte(q.msgID))
buff.Write(src)
return buff.Bytes()
}
// BuildHeader 生成编码后的Header用于发送
func (q *StarQueue) BuildHeader(length uint32) []byte {
var buff bytes.Buffer
q.msgID++
buff.Write(header)
buff.Write(Uint32ToByte(length))
buff.Write(Uint16ToByte(q.msgID))
return buff.Bytes()
}
type unFinMsg struct {
ID uint16
LengthRecv uint32
// HeaderMsg 信息头应当为14位8位识别码+4位长度码+2位id
HeaderMsg []byte
RecvMsg []byte
}
func (q *StarQueue) push2list(msg MsgQueue) {
q.msgPool <- msg
}
// ParseMessage 用于解析收到的msg信息
func (q *StarQueue) ParseMessage(msg []byte, conn interface{}) error {
return q.parseMessage(msg, conn)
}
// parseMessage 用于解析收到的msg信息
func (q *StarQueue) parseMessage(msg []byte, conn interface{}) error {
tmp, ok := q.unFinMsg.Load(conn)
if ok { //存在未完成的信息
lastMsg := tmp.(*unFinMsg)
headerLen := len(lastMsg.HeaderMsg)
if headerLen < 14 { //未完成头标题
//传输的数据不能填充header头
if len(msg) < 14-headerLen {
//加入header头并退出
lastMsg.HeaderMsg = bytesMerge(lastMsg.HeaderMsg, msg)
q.unFinMsg.Store(conn, lastMsg)
return nil
}
//获取14字节完整的header
header := msg[0 : 14-headerLen]
lastMsg.HeaderMsg = bytesMerge(lastMsg.HeaderMsg, header)
//检查收到的header是否为认证header
//若不是,丢弃并重新来过
if !checkHeader(lastMsg.HeaderMsg[0:8]) {
q.unFinMsg.Delete(conn)
if len(msg) == 0 {
return nil
}
return q.parseMessage(msg, conn)
}
//获得本数据包长度
lastMsg.LengthRecv = ByteToUint32(lastMsg.HeaderMsg[8:12])
if q.maxLength != 0 && lastMsg.LengthRecv > q.maxLength {
q.unFinMsg.Delete(conn)
return fmt.Errorf("msg length is %d ,too large than %d", lastMsg.LengthRecv, q.maxLength)
}
//获得本数据包ID
lastMsg.ID = ByteToUint16(lastMsg.HeaderMsg[12:14])
//存入列表
q.unFinMsg.Store(conn, lastMsg)
msg = msg[14-headerLen:]
if uint32(len(msg)) < lastMsg.LengthRecv {
lastMsg.RecvMsg = msg
q.unFinMsg.Store(conn, lastMsg)
return nil
}
if uint32(len(msg)) >= lastMsg.LengthRecv {
lastMsg.RecvMsg = msg[0:lastMsg.LengthRecv]
if q.Encode {
lastMsg.RecvMsg = q.DecodeFunc(lastMsg.RecvMsg)
}
msg = msg[lastMsg.LengthRecv:]
storeMsg := MsgQueue{
ID: lastMsg.ID,
Msg: lastMsg.RecvMsg,
Conn: conn,
}
//q.restoreMu.Lock()
q.push2list(storeMsg)
//q.restoreMu.Unlock()
q.unFinMsg.Delete(conn)
return q.parseMessage(msg, conn)
}
} else {
lastID := int(lastMsg.LengthRecv) - len(lastMsg.RecvMsg)
if lastID < 0 {
q.unFinMsg.Delete(conn)
return q.parseMessage(msg, conn)
}
if len(msg) >= lastID {
lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg[0:lastID])
if q.Encode {
lastMsg.RecvMsg = q.DecodeFunc(lastMsg.RecvMsg)
}
storeMsg := MsgQueue{
ID: lastMsg.ID,
Msg: lastMsg.RecvMsg,
Conn: conn,
}
q.push2list(storeMsg)
q.unFinMsg.Delete(conn)
if len(msg) == lastID {
return nil
}
msg = msg[lastID:]
return q.parseMessage(msg, conn)
}
lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg)
q.unFinMsg.Store(conn, lastMsg)
return nil
}
}
if len(msg) == 0 {
return nil
}
var start int
if start = searchHeader(msg); start == -1 {
return errors.New("data format error")
}
msg = msg[start:]
lastMsg := unFinMsg{}
q.unFinMsg.Store(conn, &lastMsg)
return q.parseMessage(msg, conn)
}
func checkHeader(msg []byte) bool {
if len(msg) != 8 {
return false
}
for k, v := range msg {
if v != header[k] {
return false
}
}
return true
}
func searchHeader(msg []byte) int {
if len(msg) < 8 {
return 0
}
for k, v := range msg {
find := 0
if v == header[0] {
for k2, v2 := range header {
if msg[k+k2] == v2 {
find++
} else {
break
}
}
if find == 8 {
return k
}
}
}
return -1
}
func bytesMerge(src ...[]byte) []byte {
var buff bytes.Buffer
for _, v := range src {
buff.Write(v)
}
return buff.Bytes()
}
// Restore 获取收到的信息
func (q *StarQueue) Restore() (MsgQueue, error) {
if q.duration.Seconds() == 0 {
q.duration = 86400 * time.Second
}
for {
select {
case <-q.ctx.Done():
return MsgQueue{}, errors.New("Stoped By External Function Call")
case <-time.After(q.duration):
if q.duration != 0 {
return MsgQueue{}, ErrDeadlineExceeded
}
case data, ok := <-q.msgPool:
if !ok {
return MsgQueue{}, os.ErrClosed
}
return data, nil
}
}
}
// RestoreOne 获取收到的一个信息
// 兼容性修改
func (q *StarQueue) RestoreOne() (MsgQueue, error) {
return q.Restore()
}
// Stop 立即停止Restore
func (q *StarQueue) Stop() {
q.cancel()
}
// RestoreDuration Restore最大超时时间
func (q *StarQueue) RestoreDuration(tm time.Duration) {
q.duration = tm
}
func (q *StarQueue) RestoreChan() <-chan MsgQueue {
return q.msgPool
}

80
que_benchmark_test.go Normal file
View File

@ -0,0 +1,80 @@
package stario
import (
"io"
"testing"
)
func BenchmarkQueueBuildMessage64KiB(b *testing.B) {
que := NewQueue()
payload := make([]byte, 64*1024)
b.ReportAllocs()
b.SetBytes(int64(len(payload)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = que.BuildMessage(payload)
}
}
func BenchmarkQueueWriteFrame64KiB(b *testing.B) {
que := NewQueue()
payload := make([]byte, 64*1024)
b.ReportAllocs()
b.SetBytes(int64(len(payload)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := que.WriteFrame(io.Discard, payload); err != nil {
b.Fatalf("WriteFrame failed: %v", err)
}
}
}
func BenchmarkQueueWriteFrameBuffers64KiB(b *testing.B) {
que := NewQueue()
payload := make([]byte, 64*1024)
b.ReportAllocs()
b.SetBytes(int64(len(payload)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := que.WriteFrameBuffers(io.Discard, payload); err != nil {
b.Fatalf("WriteFrameBuffers failed: %v", err)
}
}
}
func BenchmarkQueueWriteFramesBuffers4x64KiB(b *testing.B) {
que := NewQueue()
payloads := [][]byte{
make([]byte, 64*1024),
make([]byte, 64*1024),
make([]byte, 64*1024),
make([]byte, 64*1024),
}
b.ReportAllocs()
b.SetBytes(int64(len(payloads) * len(payloads[0])))
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := que.WriteFramesBuffers(io.Discard, payloads...); err != nil {
b.Fatalf("WriteFramesBuffers failed: %v", err)
}
}
}
func BenchmarkQueueParseRestoreHello(b *testing.B) {
que := NewQueueWithCount(1)
frame := que.BuildMessage([]byte("hello"))
if len(frame) == 0 {
b.Fatal("BuildMessage returned empty frame")
}
b.ReportAllocs()
b.SetBytes(int64(len(frame)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := que.ParseMessage(frame, "bench"); err != nil {
b.Fatal(err)
}
if _, err := que.Restore(); err != nil {
b.Fatal(err)
}
}
}

34
que_bytes.go Normal file
View File

@ -0,0 +1,34 @@
package stario
import "encoding/binary"
// Uint32ToByte 4位uint32转byte。
func Uint32ToByte(src uint32) []byte {
res := make([]byte, 4)
binary.BigEndian.PutUint32(res, src)
return res
}
// ByteToUint32 byte转4位uint32。
func ByteToUint32(src []byte) uint32 {
return binary.BigEndian.Uint32(src)
}
// Uint16ToByte 2位uint16转byte。
func Uint16ToByte(src uint16) []byte {
res := make([]byte, 2)
binary.BigEndian.PutUint16(res, src)
return res
}
// ByteToUint16 用于byte转uint16。
func ByteToUint16(src []byte) uint16 {
return binary.BigEndian.Uint16(src)
}
func cloneBytes(src []byte) []byte {
if len(src) == 0 {
return nil
}
return append([]byte(nil), src...)
}

208
que_frame.go Normal file
View File

@ -0,0 +1,208 @@
package stario
import (
"encoding/binary"
"fmt"
"io"
"net"
)
// BuildMessage builds one full frame and panics if the payload is too large to
// fit in the framing format.
//
// New code should prefer BuildMessageErr when it needs an explicit error path.
func (q *StarQueue) BuildMessage(src []byte) []byte {
frame, err := q.BuildMessageErr(src)
if err != nil {
panic(err)
}
return frame
}
// BuildMessageErr builds one full frame and returns an explicit error when the
// payload exceeds the 32-bit frame length.
func (q *StarQueue) BuildMessageErr(src []byte) ([]byte, error) {
payload := src
if q.Encode && q.EncodeFunc != nil {
payload = q.EncodeFunc(payload)
}
length, err := payloadSizeToUint32(uint64(len(payload)))
if err != nil {
return nil, err
}
header := q.BuildHeader(length)
frame := make([]byte, 0, len(header)+len(payload))
frame = append(frame, header...)
frame = append(frame, payload...)
return frame, nil
}
// WriteFrame writes one framed payload directly to w without building an
// intermediate full-frame slice first.
func (q *StarQueue) WriteFrame(w io.Writer, src []byte) error {
if w == nil {
return io.ErrClosedPipe
}
payload := src
if q.Encode && q.EncodeFunc != nil {
payload = q.EncodeFunc(payload)
}
length, err := payloadSizeToUint32(uint64(len(payload)))
if err != nil {
return err
}
var header [queHeaderSize]byte
writeHeaderBytes(header[:], queHeader{
Length: length,
Version: queVersionV1,
Flags: queSupportedFlags,
})
if err := writeFull(w, header[:]); err != nil {
return err
}
return writeFull(w, payload)
}
// WriteFrameBuffers writes one framed payload using net.Buffers so callers can
// opt into gather writes when the underlying writer supports it well.
func (q *StarQueue) WriteFrameBuffers(w io.Writer, src []byte) error {
if w == nil {
return io.ErrClosedPipe
}
payload := src
if q.Encode && q.EncodeFunc != nil {
payload = q.EncodeFunc(payload)
}
length, err := payloadSizeToUint32(uint64(len(payload)))
if err != nil {
return err
}
var header [queHeaderSize]byte
writeHeaderBytes(header[:], queHeader{
Length: length,
Version: queVersionV1,
Flags: queSupportedFlags,
})
buffers := net.Buffers{header[:], payload}
_, err = buffers.WriteTo(w)
return err
}
// WriteFramesBuffers writes multiple framed payloads using one net.Buffers
// batch so callers can reduce write calls on stream transports.
func (q *StarQueue) WriteFramesBuffers(w io.Writer, payloads ...[]byte) error {
if w == nil {
return io.ErrClosedPipe
}
if len(payloads) == 0 {
return nil
}
buffers := make(net.Buffers, 0, len(payloads)*2)
headers := make([][queHeaderSize]byte, len(payloads))
for i, src := range payloads {
payload := src
if q.Encode && q.EncodeFunc != nil {
payload = q.EncodeFunc(payload)
}
length, err := payloadSizeToUint32(uint64(len(payload)))
if err != nil {
return err
}
writeHeaderBytes(headers[i][:], queHeader{
Length: length,
Version: queVersionV1,
Flags: queSupportedFlags,
})
buffers = append(buffers, headers[i][:], payload)
}
_, err := buffers.WriteTo(w)
return err
}
// BuildHeader 生成编码后的Header用于发送。
func (q *StarQueue) BuildHeader(length uint32) []byte {
return buildHeaderBytes(queHeader{
Length: length,
Version: queVersionV1,
Flags: queSupportedFlags,
})
}
func buildHeaderBytes(header queHeader) []byte {
buf := make([]byte, queHeaderSize)
writeHeaderBytes(buf, header)
return buf
}
func writeHeaderBytes(dst []byte, header queHeader) {
if len(dst) < queHeaderSize {
return
}
copy(dst[:queMagicSize], queMagic)
binary.BigEndian.PutUint32(dst[queMagicSize:queMagicSize+4], header.Length)
dst[12] = header.Version
dst[13] = header.Flags
}
func payloadSizeToUint32(size uint64) (uint32, error) {
const maxFramePayload = ^uint32(0)
if size > uint64(maxFramePayload) {
return 0, fmt.Errorf("%w: %d > %d", ErrQueueMessageTooLarge, size, uint64(maxFramePayload))
}
return uint32(size), nil
}
func writeFull(w io.Writer, data []byte) error {
for len(data) > 0 {
n, err := w.Write(data)
if n > 0 {
data = data[n:]
}
if err != nil {
return err
}
if n == 0 {
return io.ErrNoProgress
}
}
return nil
}
func parseHeaderBytes(src []byte, maxLength uint32) (queHeader, error) {
if len(src) < queHeaderSize {
return queHeader{}, ErrQueueDataFormat
}
if !equalMagic(src[:queMagicSize]) {
return queHeader{}, ErrQueueDataFormat
}
header := queHeader{
Length: ByteToUint32(src[queMagicSize : queMagicSize+4]),
Version: src[12],
Flags: src[13],
}
if header.Version != queVersionV1 {
return queHeader{}, fmt.Errorf("%w: %d", ErrQueueUnsupportedVersion, header.Version)
}
if header.Flags != queSupportedFlags {
return queHeader{}, fmt.Errorf("%w: %d", ErrQueueUnsupportedFlags, header.Flags)
}
if maxLength != 0 && header.Length > maxLength {
return queHeader{}, fmt.Errorf("%w: %d > %d", ErrQueueMessageTooLarge, header.Length, maxLength)
}
return header, nil
}
func equalMagic(src []byte) bool {
if len(src) != queMagicSize {
return false
}
for i, b := range src {
if b != queMagic[i] {
return false
}
}
return true
}

237
que_parse.go Normal file
View File

@ -0,0 +1,237 @@
package stario
import (
"bytes"
"fmt"
"reflect"
)
// ParseMessage 用于解析收到的msg信息。
func (q *StarQueue) ParseMessage(msg []byte, conn interface{}) error {
return q.parseMessage(msg, conn, false, func(payload []byte) error {
return q.push2list(MsgQueue{
Msg: payload,
Conn: conn,
})
})
}
// ParseMessageView parses frames and exposes each payload to fn without
// forcing StarQueue to clone it first.
//
// The provided payload is only valid during the callback. If q uses the legacy
// DecodeFunc path, StarQueue still has to allocate a decoded payload first.
func (q *StarQueue) ParseMessageView(msg []byte, conn interface{}, fn func(FrameView) error) error {
if fn == nil {
return ErrQueueFrameHandlerNil
}
return q.parseMessage(msg, conn, true, func(payload []byte) error {
return fn(FrameView{
Payload: payload,
Conn: conn,
})
})
}
// ParseMessageOwned parses frames and emits owned payload copies to fn without
// routing them through RestoreChan.
//
// Compared with ParseMessage, this keeps StarQueue state handling the same but
// lets callers decide how to dispatch parsed messages themselves.
func (q *StarQueue) ParseMessageOwned(msg []byte, conn interface{}, fn func(MsgQueue) error) error {
if fn == nil {
return ErrQueueFrameHandlerNil
}
frames := make([]MsgQueue, 0, 1)
parseErr := q.parseMessage(msg, conn, false, func(payload []byte) error {
frames = append(frames, MsgQueue{
Msg: payload,
Conn: conn,
})
return nil
})
for _, frame := range frames {
if err := fn(frame); err != nil {
if parseErr != nil {
return fmt.Errorf("%v: %w", parseErr, err)
}
return err
}
}
return parseErr
}
func (q *StarQueue) parseMessage(msg []byte, conn interface{}, borrowPayload bool, emit func([]byte) error) error {
state, err := q.connState(conn)
if err != nil {
return err
}
var firstErr error
for {
payload, ok, err := q.nextPayload(state, conn, msg, borrowPayload)
msg = nil
if err != nil && firstErr == nil {
firstErr = err
}
if !ok {
break
}
if err := emit(payload); err != nil {
if firstErr != nil {
return fmt.Errorf("%v: %w", firstErr, err)
}
return err
}
}
return firstErr
}
func (q *StarQueue) nextPayload(state *queConnState, conn interface{}, msg []byte, borrowPayload bool) ([]byte, bool, error) {
state.mu.Lock()
defer state.mu.Unlock()
if len(msg) != 0 {
state.buf = append(state.buf, msg...)
}
var firstErr error
for {
synced, err := syncFrameStart(&state.buf)
if err != nil && firstErr == nil {
firstErr = err
}
if !synced {
if len(state.buf) == 0 {
q.states.Delete(conn)
}
return nil, false, firstErr
}
if len(state.buf) < queHeaderSize {
return nil, false, firstErr
}
header, err := parseHeaderBytes(state.buf[:queHeaderSize], q.maxLength)
if err != nil {
if firstErr == nil {
firstErr = err
}
state.buf = shrinkBuffer(state.buf[1:])
continue
}
frameLen := queHeaderSize + int(header.Length)
if len(state.buf) < frameLen {
return nil, false, firstErr
}
payload, rest := extractPayload(state.buf, frameLen, borrowPayload && !(q.Encode && q.DecodeFunc != nil))
state.buf = rest
if q.Encode && q.DecodeFunc != nil {
payload = q.DecodeFunc(payload)
}
if len(state.buf) == 0 {
q.states.Delete(conn)
}
return payload, true, firstErr
}
}
func extractPayload(buf []byte, frameLen int, borrowPayload bool) ([]byte, []byte) {
payload := buf[queHeaderSize:frameLen]
if !borrowPayload {
return cloneBytes(payload), shrinkBuffer(buf[frameLen:])
}
if frameLen == len(buf) {
return payload, nil
}
return payload, cloneBytes(buf[frameLen:])
}
func (q *StarQueue) push2list(msg MsgQueue) error {
select {
case <-q.ctx.Done():
return q.ctx.Err()
default:
}
q.sendMu.RLock()
defer q.sendMu.RUnlock()
select {
case <-q.ctx.Done():
return q.ctx.Err()
case q.msgPool <- msg:
return nil
}
}
func validateConnKey(conn interface{}) error {
if conn == nil {
return ErrQueueConnKeyNil
}
typ := reflect.TypeOf(conn)
if typ != nil && !typ.Comparable() {
return ErrQueueConnKeyInvalid
}
return nil
}
func (q *StarQueue) connState(conn interface{}) (*queConnState, error) {
if err := validateConnKey(conn); err != nil {
return nil, err
}
state, _ := q.states.LoadOrStore(conn, &queConnState{})
return state.(*queConnState), nil
}
func syncFrameStart(buf *[]byte) (bool, error) {
if len(*buf) == 0 {
return false, nil
}
if len(*buf) >= queMagicSize && equalMagic((*buf)[:queMagicSize]) {
return true, nil
}
idx := bytes.Index(*buf, queMagic)
if idx == 0 {
return true, nil
}
if idx > 0 {
*buf = cloneBytes((*buf)[idx:])
return true, ErrQueueDataFormat
}
keep := trailingMagicPrefixLen(*buf)
if keep == len(*buf) {
return false, nil
}
if keep > 0 {
*buf = cloneBytes((*buf)[len(*buf)-keep:])
return false, ErrQueueDataFormat
}
*buf = (*buf)[:0]
return false, ErrQueueDataFormat
}
func trailingMagicPrefixLen(buf []byte) int {
max := len(buf)
if max > queMagicSize-1 {
max = queMagicSize - 1
}
for keep := max; keep > 0; keep-- {
if bytes.Equal(buf[len(buf)-keep:], queMagic[:keep]) {
return keep
}
}
return 0
}
func shrinkBuffer(buf []byte) []byte {
if len(buf) == 0 {
return nil
}
if cap(buf) > len(buf)*4 {
return cloneBytes(buf)
}
return buf
}

80
que_runtime.go Normal file
View File

@ -0,0 +1,80 @@
package stario
import (
"os"
"time"
)
func (q *StarQueue) closedRestoreErr() error {
if err := q.ctx.Err(); err != nil {
return err
}
return os.ErrClosed
}
// Restore blocks until one message is available, the queue is stopped, or the
// configured timeout expires.
func (q *StarQueue) Restore() (MsgQueue, error) {
if q.duration <= 0 {
select {
case <-q.ctx.Done():
return MsgQueue{}, q.ctx.Err()
case data, ok := <-q.msgPool:
if !ok {
return MsgQueue{}, q.closedRestoreErr()
}
return data, nil
}
}
timer := time.NewTimer(q.duration)
defer timer.Stop()
select {
case <-q.ctx.Done():
return MsgQueue{}, q.ctx.Err()
case <-timer.C:
return MsgQueue{}, ErrDeadlineExceeded
case data, ok := <-q.msgPool:
if !ok {
return MsgQueue{}, q.closedRestoreErr()
}
return data, nil
}
}
// RestoreOne 获取收到的一个信息。
// 兼容性修改。
func (q *StarQueue) RestoreOne() (MsgQueue, error) {
return q.Restore()
}
// Stop cancels the queue runtime.
//
// After Stop returns, Restore unblocks with context.Canceled and RestoreChan is
// eventually closed.
func (q *StarQueue) Stop() {
q.cancel()
q.shutdown()
}
func (q *StarQueue) shutdown() {
q.stopOnce.Do(func() {
q.sendMu.Lock()
close(q.msgPool)
q.sendMu.Unlock()
})
}
// RestoreDuration sets the Restore timeout. A non-positive duration means wait
// forever until a message arrives or Stop is called.
func (q *StarQueue) RestoreDuration(tm time.Duration) {
q.duration = tm
}
// RestoreChan exposes the parsed message stream.
//
// The returned channel is closed after Stop or when the queue context is
// canceled. New code should still prefer Restore when it needs timeout/error
// classification instead of a plain stream.
func (q *StarQueue) RestoreChan() <-chan MsgQueue {
return q.msgPool
}

View File

@ -1,42 +1,520 @@
package stario package stario
import ( import (
"fmt" "bytes"
"context"
"errors"
"io"
"testing" "testing"
"time" "time"
) )
func Test_QueSpeed(t *testing.T) { func TestQueueBuildMessageUsesVersionedHeader(t *testing.T) {
que := NewQueueWithCount(0) que := NewQueue()
stop := make(chan struct{}, 1) frame := que.BuildMessage([]byte("hello"))
que.RestoreDuration(time.Second * 10)
var count int64 if len(frame) != queHeaderSize+5 {
t.Fatalf("unexpected frame length: got %d want %d", len(frame), queHeaderSize+5)
}
if !bytes.Equal(frame[:queMagicSize], queMagic) {
t.Fatalf("unexpected magic: %v", frame[:queMagicSize])
}
if got := ByteToUint32(frame[queMagicSize : queMagicSize+4]); got != 5 {
t.Fatalf("unexpected payload length: got %d want 5", got)
}
if frame[12] != queVersionV1 {
t.Fatalf("unexpected version: got %d want %d", frame[12], queVersionV1)
}
if frame[13] != queSupportedFlags {
t.Fatalf("unexpected flags: got %d want %d", frame[13], queSupportedFlags)
}
if !bytes.Equal(frame[queHeaderSize:], []byte("hello")) {
t.Fatalf("unexpected payload: %q", frame[queHeaderSize:])
}
}
func TestQueueWriteFrameMatchesBuildMessage(t *testing.T) {
que := NewQueue()
want := que.BuildMessage([]byte("hello"))
var buf bytes.Buffer
if err := que.WriteFrame(&buf, []byte("hello")); err != nil {
t.Fatalf("WriteFrame failed: %v", err)
}
if got := buf.Bytes(); !bytes.Equal(got, want) {
t.Fatalf("WriteFrame mismatch: got %v want %v", got, want)
}
}
func TestQueueWriteFrameBuffersMatchesBuildMessage(t *testing.T) {
que := NewQueue()
want := que.BuildMessage([]byte("hello"))
var buf bytes.Buffer
if err := que.WriteFrameBuffers(&buf, []byte("hello")); err != nil {
t.Fatalf("WriteFrameBuffers failed: %v", err)
}
if got := buf.Bytes(); !bytes.Equal(got, want) {
t.Fatalf("WriteFrameBuffers mismatch: got %v want %v", got, want)
}
}
func TestQueueWriteFramesBuffersMatchesBuildMessage(t *testing.T) {
que := NewQueue()
payloads := [][]byte{
[]byte("hello"),
[]byte("world"),
[]byte("batch"),
}
var want []byte
for _, payload := range payloads {
want = append(want, que.BuildMessage(payload)...)
}
var buf bytes.Buffer
if err := que.WriteFramesBuffers(&buf, payloads...); err != nil {
t.Fatalf("WriteFramesBuffers failed: %v", err)
}
if got := buf.Bytes(); !bytes.Equal(got, want) {
t.Fatalf("WriteFramesBuffers mismatch: got %v want %v", got, want)
}
}
func TestQueueWriteFrameHonorsEncodeFunc(t *testing.T) {
que := NewQueue()
que.Encode = true
que.EncodeFunc = bytes.ToUpper
want := que.BuildMessage([]byte("hello"))
var buf bytes.Buffer
if err := que.WriteFrame(&buf, []byte("hello")); err != nil {
t.Fatalf("WriteFrame failed: %v", err)
}
if got := buf.Bytes(); !bytes.Equal(got, want) {
t.Fatalf("WriteFrame mismatch with EncodeFunc: got %v want %v", got, want)
}
}
func TestQueueWriteFrameBuffersHonorsEncodeFunc(t *testing.T) {
que := NewQueue()
que.Encode = true
que.EncodeFunc = bytes.ToUpper
want := que.BuildMessage([]byte("hello"))
var buf bytes.Buffer
if err := que.WriteFrameBuffers(&buf, []byte("hello")); err != nil {
t.Fatalf("WriteFrameBuffers failed: %v", err)
}
if got := buf.Bytes(); !bytes.Equal(got, want) {
t.Fatalf("WriteFrameBuffers mismatch with EncodeFunc: got %v want %v", got, want)
}
}
func TestQueueWriteFramesBuffersHonorsEncodeFunc(t *testing.T) {
que := NewQueue()
que.Encode = true
que.EncodeFunc = bytes.ToUpper
payloads := [][]byte{
[]byte("hello"),
[]byte("batch"),
}
var want []byte
for _, payload := range payloads {
want = append(want, que.BuildMessage(payload)...)
}
var buf bytes.Buffer
if err := que.WriteFramesBuffers(&buf, payloads...); err != nil {
t.Fatalf("WriteFramesBuffers failed: %v", err)
}
if got := buf.Bytes(); !bytes.Equal(got, want) {
t.Fatalf("WriteFramesBuffers mismatch with EncodeFunc: got %v want %v", got, want)
}
}
func TestQueueParseMessageSplitAcrossCalls(t *testing.T) {
que := NewQueueWithCount(1)
frame := que.BuildMessage([]byte("hello"))
if err := que.ParseMessage(frame[:3], "split"); err != nil {
t.Fatalf("unexpected error on partial magic: %v", err)
}
if err := que.ParseMessage(frame[3:11], "split"); err != nil {
t.Fatalf("unexpected error on partial header: %v", err)
}
if err := que.ParseMessage(frame[11:], "split"); err != nil {
t.Fatalf("unexpected error on payload completion: %v", err)
}
select {
case data := <-que.RestoreChan():
if data.ID != 0 {
t.Fatalf("expected deprecated frame ID to stay zero, got %d", data.ID)
}
if data.Conn != "split" {
t.Fatalf("unexpected conn: %#v", data.Conn)
}
if !bytes.Equal(data.Msg, []byte("hello")) {
t.Fatalf("unexpected payload: %q", data.Msg)
}
case <-time.After(200 * time.Millisecond):
t.Fatal("did not restore parsed frame")
}
}
func TestQueueParseMessageViewSplitAcrossCalls(t *testing.T) {
que := NewQueue()
frame := que.BuildMessage([]byte("hello"))
var got [][]byte
handler := func(view FrameView) error {
got = append(got, cloneBytes(view.Payload))
if view.Conn != "split-view" {
t.Fatalf("unexpected conn: %#v", view.Conn)
}
return nil
}
if err := que.ParseMessageView(frame[:3], "split-view", handler); err != nil {
t.Fatalf("unexpected error on partial magic: %v", err)
}
if err := que.ParseMessageView(frame[3:11], "split-view", handler); err != nil {
t.Fatalf("unexpected error on partial header: %v", err)
}
if err := que.ParseMessageView(frame[11:], "split-view", handler); err != nil {
t.Fatalf("unexpected error on payload completion: %v", err)
}
if len(got) != 1 {
t.Fatalf("unexpected frame count: got %d want 1", len(got))
}
if !bytes.Equal(got[0], []byte("hello")) {
t.Fatalf("unexpected payload: %q", got[0])
}
}
func TestQueueParseMessageOwnedSplitAcrossCalls(t *testing.T) {
que := NewQueue()
frame := que.BuildMessage([]byte("hello"))
var got []MsgQueue
handler := func(msg MsgQueue) error {
got = append(got, MsgQueue{
Msg: cloneBytes(msg.Msg),
Conn: msg.Conn,
})
return nil
}
if err := que.ParseMessageOwned(frame[:3], "split-owned", handler); err != nil {
t.Fatalf("unexpected error on partial magic: %v", err)
}
if err := que.ParseMessageOwned(frame[3:11], "split-owned", handler); err != nil {
t.Fatalf("unexpected error on partial header: %v", err)
}
if err := que.ParseMessageOwned(frame[11:], "split-owned", handler); err != nil {
t.Fatalf("unexpected error on payload completion: %v", err)
}
if len(got) != 1 {
t.Fatalf("unexpected frame count: got %d want 1", len(got))
}
if got[0].Conn != "split-owned" {
t.Fatalf("unexpected conn: %#v", got[0].Conn)
}
if !bytes.Equal(got[0].Msg, []byte("hello")) {
t.Fatalf("unexpected payload: %q", got[0].Msg)
}
select {
case msg := <-que.RestoreChan():
t.Fatalf("ParseMessageOwned should not use RestoreChan, got %#v", msg)
default:
}
}
func TestQueueParseMessageSkipsGarbagePrefix(t *testing.T) {
que := NewQueueWithCount(1)
frame := que.BuildMessage([]byte("hello"))
err := que.ParseMessage(append([]byte("junk"), frame...), "garbage")
if !errors.Is(err, ErrQueueDataFormat) {
t.Fatalf("expected data format error, got %v", err)
}
select {
case data := <-que.RestoreChan():
if !bytes.Equal(data.Msg, []byte("hello")) {
t.Fatalf("unexpected payload after resync: %q", data.Msg)
}
case <-time.After(200 * time.Millisecond):
t.Fatal("did not restore frame after skipping garbage")
}
}
func TestQueueParseMessageViewSkipsGarbagePrefix(t *testing.T) {
que := NewQueue()
frame := que.BuildMessage([]byte("hello"))
var got []byte
err := que.ParseMessageView(append([]byte("junk"), frame...), "garbage-view", func(view FrameView) error {
got = cloneBytes(view.Payload)
return nil
})
if !errors.Is(err, ErrQueueDataFormat) {
t.Fatalf("expected data format error, got %v", err)
}
if !bytes.Equal(got, []byte("hello")) {
t.Fatalf("unexpected payload after resync: %q", got)
}
}
func TestQueueParseMessageViewNilHandler(t *testing.T) {
que := NewQueue()
err := que.ParseMessageView(que.BuildMessage([]byte("hello")), "nil-handler", nil)
if !errors.Is(err, ErrQueueFrameHandlerNil) {
t.Fatalf("ParseMessageView error = %v, want %v", err, ErrQueueFrameHandlerNil)
}
}
func TestQueueParseMessageOwnedNilHandler(t *testing.T) {
que := NewQueue()
err := que.ParseMessageOwned(que.BuildMessage([]byte("hello")), "nil-handler-owned", nil)
if !errors.Is(err, ErrQueueFrameHandlerNil) {
t.Fatalf("ParseMessageOwned error = %v, want %v", err, ErrQueueFrameHandlerNil)
}
}
func TestQueueWriteFrameNilWriter(t *testing.T) {
que := NewQueue()
err := que.WriteFrame(nil, []byte("hello"))
if !errors.Is(err, io.ErrClosedPipe) {
t.Fatalf("WriteFrame error = %v, want %v", err, io.ErrClosedPipe)
}
}
func TestQueueWriteFrameBuffersNilWriter(t *testing.T) {
que := NewQueue()
err := que.WriteFrameBuffers(nil, []byte("hello"))
if !errors.Is(err, io.ErrClosedPipe) {
t.Fatalf("WriteFrameBuffers error = %v, want %v", err, io.ErrClosedPipe)
}
}
func TestQueueWriteFramesBuffersNilWriter(t *testing.T) {
que := NewQueue()
err := que.WriteFramesBuffers(nil, []byte("hello"))
if !errors.Is(err, io.ErrClosedPipe) {
t.Fatalf("WriteFramesBuffers error = %v, want %v", err, io.ErrClosedPipe)
}
}
func TestQueueParseMessageRejectsUnsupportedVersion(t *testing.T) {
que := NewQueueWithCount(1)
frame := que.BuildMessage([]byte("hello"))
frame[12] = 2
err := que.ParseMessage(frame, "version")
if !errors.Is(err, ErrQueueUnsupportedVersion) {
t.Fatalf("expected unsupported version error, got %v", err)
}
select {
case data := <-que.RestoreChan():
t.Fatalf("unexpected restored frame: %#v", data)
default:
}
}
func TestQueueParseMessageRejectsMessageTooLarge(t *testing.T) {
que := NewQueueCtx(nil, 1, 4)
frame := que.BuildMessage([]byte("hello"))
err := que.ParseMessage(frame, "large")
if !errors.Is(err, ErrQueueMessageTooLarge) {
t.Fatalf("expected message too large error, got %v", err)
}
select {
case data := <-que.RestoreChan():
t.Fatalf("unexpected restored frame: %#v", data)
default:
}
}
func TestQueueParseMessageRejectsInvalidConnKey(t *testing.T) {
que := NewQueue()
frame := que.BuildMessage([]byte("hello"))
err := que.ParseMessage(frame, []byte("not-comparable"))
if !errors.Is(err, ErrQueueConnKeyInvalid) {
t.Fatalf("expected invalid conn key error, got %v", err)
}
}
func TestQueueParseMessageRejectsNilConnKey(t *testing.T) {
que := NewQueue()
frame := que.BuildMessage([]byte("hello"))
err := que.ParseMessage(frame, nil)
if !errors.Is(err, ErrQueueConnKeyNil) {
t.Fatalf("expected nil conn key error, got %v", err)
}
}
func TestQueuePayloadSizeToUint32RejectsOverflow(t *testing.T) {
_, err := payloadSizeToUint32(uint64(^uint32(0)) + 1)
if !errors.Is(err, ErrQueueMessageTooLarge) {
t.Fatalf("expected message too large error, got %v", err)
}
}
func TestQueueRestoreDurationZeroWaitsUntilMessage(t *testing.T) {
que := NewQueueWithCount(1)
que.RestoreDuration(0)
type restoreResult struct {
msg MsgQueue
err error
}
resultCh := make(chan restoreResult, 1)
go func() { go func() {
for { msg, err := que.Restore()
select { resultCh <- restoreResult{msg: msg, err: err}
case <-stop:
//fmt.Println(count)
return
default:
}
_, err := que.RestoreOne()
if err == nil {
count++
}
}
}() }()
cp := 0
stoped := time.After(time.Second * 10)
data := que.BuildMessage([]byte("hello"))
for {
select { select {
case <-stoped: case result := <-resultCh:
fmt.Println(count, cp) t.Fatalf("Restore returned too early: %#v", result)
stop <- struct{}{} case <-time.After(50 * time.Millisecond):
return }
default:
que.ParseMessage(data, "lala") if err := que.ParseMessage(que.BuildMessage([]byte("hello")), "forever"); err != nil {
cp++ t.Fatalf("ParseMessage failed: %v", err)
}
select {
case result := <-resultCh:
if result.err != nil {
t.Fatalf("Restore returned error: %v", result.err)
}
if result.msg.Conn != "forever" || !bytes.Equal(result.msg.Msg, []byte("hello")) {
t.Fatalf("unexpected restore result: %#v", result.msg)
}
case <-time.After(200 * time.Millisecond):
t.Fatal("Restore did not return after message arrival")
} }
} }
func TestQueueRestoreReturnsContextErrorOnStop(t *testing.T) {
que := NewQueue()
resultCh := make(chan error, 1)
go func() {
_, err := que.Restore()
resultCh <- err
}()
que.Stop()
select {
case err := <-resultCh:
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected context.Canceled, got %v", err)
}
case <-time.After(200 * time.Millisecond):
t.Fatal("Restore did not return after Stop")
}
}
func TestQueueRestoreChanClosesOnStop(t *testing.T) {
que := NewQueue()
resultCh := make(chan bool, 1)
go func() {
_, ok := <-que.RestoreChan()
resultCh <- ok
}()
que.Stop()
select {
case ok := <-resultCh:
if ok {
t.Fatal("expected RestoreChan to close after Stop")
}
case <-time.After(200 * time.Millisecond):
t.Fatal("RestoreChan did not close after Stop")
}
}
func TestQueueRestoreChanClosesOnContextCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
que := NewQueueCtx(ctx, 1, 0)
resultCh := make(chan bool, 1)
go func() {
_, ok := <-que.RestoreChan()
resultCh <- ok
}()
cancel()
select {
case ok := <-resultCh:
if ok {
t.Fatal("expected RestoreChan to close after context cancel")
}
case <-time.After(200 * time.Millisecond):
t.Fatal("RestoreChan did not close after context cancel")
}
}
func TestQueueParseMessageReturnsContextErrorWhenStoppedWhilePoolIsFull(t *testing.T) {
que := NewQueueWithCount(1)
if err := que.ParseMessage(que.BuildMessage([]byte("first")), "full"); err != nil {
t.Fatalf("ParseMessage first failed: %v", err)
}
errCh := make(chan error, 1)
go func() {
errCh <- que.ParseMessage(que.BuildMessage([]byte("second")), "full")
}()
select {
case err := <-errCh:
t.Fatalf("ParseMessage returned before Stop: %v", err)
case <-time.After(50 * time.Millisecond):
}
que.Stop()
select {
case err := <-errCh:
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected context.Canceled, got %v", err)
}
case <-time.After(200 * time.Millisecond):
t.Fatal("ParseMessage did not return after Stop")
}
}
func TestQueueParseMessageViewAllowsReentrantParseOnSameConn(t *testing.T) {
que := NewQueue()
reentered := false
err := que.ParseMessageView(que.BuildMessage([]byte("outer")), "reentrant", func(view FrameView) error {
if !bytes.Equal(view.Payload, []byte("outer")) {
t.Fatalf("unexpected outer payload: %q", view.Payload)
}
return que.ParseMessageView(que.BuildMessage([]byte("inner")), "reentrant", func(inner FrameView) error {
reentered = true
if !bytes.Equal(inner.Payload, []byte("inner")) {
t.Fatalf("unexpected inner payload: %q", inner.Payload)
}
return nil
})
})
if err != nil {
t.Fatalf("ParseMessageView failed: %v", err)
}
if !reentered {
t.Fatal("expected reentrant ParseMessageView to run")
}
} }

101
que_types.go Normal file
View File

@ -0,0 +1,101 @@
package stario
import (
"context"
"errors"
"sync"
"time"
)
var ErrDeadlineExceeded = errors.New("deadline exceeded")
var ErrQueueDataFormat = errors.New("data format error")
var ErrQueueUnsupportedVersion = errors.New("unsupported frame version")
var ErrQueueUnsupportedFlags = errors.New("unsupported frame flags")
var ErrQueueMessageTooLarge = errors.New("message too large")
var ErrQueueFrameHandlerNil = errors.New("frame handler is nil")
var ErrQueueConnKeyInvalid = errors.New("queue conn key must be comparable")
var ErrQueueConnKeyNil = errors.New("queue conn key must not be nil")
const (
queMagicSize = 8
queHeaderSize = 14
queVersionV1 = 1
queSupportedFlags = 0
)
// 识别头
var queMagic = []byte{11, 27, 19, 96, 12, 25, 02, 20}
// MsgQueue 为基本的信息单位。
type MsgQueue struct {
// Deprecated: frame-level IDs are no longer emitted by StarQueue v2 framing.
ID uint16
Msg []byte
Conn interface{}
}
// FrameView exposes a parsed payload without forcing StarQueue to clone it.
//
// The payload is only valid during the ParseMessageView callback. Callers must
// copy it if they need to keep it after the callback returns.
type FrameView struct {
Payload []byte
Conn interface{}
}
// StarQueue 为流数据中的消息队列分发。
type StarQueue struct {
maxLength uint32
msgPool chan MsgQueue
states sync.Map
ctx context.Context
cancel context.CancelFunc
duration time.Duration
sendMu sync.RWMutex
stopOnce sync.Once
// Deprecated: new code should keep StarQueue focused on framing only.
Encode bool
// Deprecated: new code should keep StarQueue focused on framing only.
EncodeFunc func([]byte) []byte
// Deprecated: new code should keep StarQueue focused on framing only.
DecodeFunc func([]byte) []byte
}
type queHeader struct {
Length uint32
Version uint8
Flags uint8
}
type queConnState struct {
mu sync.Mutex
buf []byte
}
func NewQueueCtx(ctx context.Context, count int64, maxMsgLength uint32) *StarQueue {
if count < 0 {
panic("stario: negative queue count")
}
q := &StarQueue{
maxLength: maxMsgLength,
msgPool: make(chan MsgQueue, count),
}
if ctx == nil {
q.ctx, q.cancel = context.WithCancel(context.Background())
} else {
q.ctx, q.cancel = context.WithCancel(ctx)
}
context.AfterFunc(q.ctx, q.shutdown)
return q
}
func NewQueueWithCount(count int64) *StarQueue {
return NewQueueCtx(nil, count, 0)
}
// NewQueue 建立一个新消息队列。
func NewQueue() *StarQueue {
return NewQueueWithCount(32)
}

44
signal_error.go Normal file
View File

@ -0,0 +1,44 @@
package stario
import "errors"
var (
ErrSignalInterrupt = errors.New("interrupt")
ErrSignalStop = errors.New("stop")
ErrSignalQuit = errors.New("quit")
)
type inputSignalError struct {
msg string
cause error
}
func (e *inputSignalError) Error() string {
return e.msg
}
func (e *inputSignalError) Unwrap() error {
return e.cause
}
func signalErrorForType(sigtype rune) error {
switch sigtype {
case 0x03:
return &inputSignalError{
msg: "SIGNAL SIGINT RECIVED",
cause: ErrSignalInterrupt,
}
case 0x1a:
return &inputSignalError{
msg: "SIGNAL SIGSTOP RECIVED",
cause: ErrSignalStop,
}
case 0x1c:
return &inputSignalError{
msg: "SIGNAL SIGQUIT RECIVED",
cause: ErrSignalQuit,
}
default:
return nil
}
}

50
signal_error_test.go Normal file
View File

@ -0,0 +1,50 @@
package stario
import (
"errors"
"testing"
)
func TestSignalErrorForType(t *testing.T) {
cases := []struct {
name string
sig rune
msg string
want error
}{
{name: "interrupt", sig: 0x03, msg: "SIGNAL SIGINT RECIVED", want: ErrSignalInterrupt},
{name: "stop", sig: 0x1a, msg: "SIGNAL SIGSTOP RECIVED", want: ErrSignalStop},
{name: "quit", sig: 0x1c, msg: "SIGNAL SIGQUIT RECIVED", want: ErrSignalQuit},
}
for _, tc := range cases {
err := signalErrorForType(tc.sig)
if err == nil {
t.Fatalf("%s: expected non-nil error", tc.name)
}
if err.Error() != tc.msg {
t.Fatalf("%s: unexpected error text: got %q want %q", tc.name, err.Error(), tc.msg)
}
if !errors.Is(err, tc.want) {
t.Fatalf("%s: errors.Is mismatch: got %v want %v", tc.name, err, tc.want)
}
}
}
func TestSignalErrorForUnknownType(t *testing.T) {
if err := signalErrorForType('x'); err != nil {
t.Fatalf("expected nil error for unknown signal, got %v", err)
}
}
func TestSignalReturnsTypedError(t *testing.T) {
err := signal(0x03)
if !errors.Is(err, ErrSignalInterrupt) {
t.Fatalf("expected interrupt error, got %v", err)
}
}
func TestSignalReturnsNilForUnknownType(t *testing.T) {
if err := signal('x'); err != nil {
t.Fatalf("expected nil error for unknown signal, got %v", err)
}
}

View File

@ -3,20 +3,6 @@
package stario package stario
import (
"errors"
)
func signal(sigtype rune) error { func signal(sigtype rune) error {
//todo: use win32api call signal return signalErrorForType(sigtype)
switch sigtype {
case 0x03:
return errors.New("SIGNAL SIGINT RECIVED")
case 0x1a:
return errors.New("SIGNAL SIGSTOP RECIVED")
case 0x1c:
return errors.New("SIGNAL SIGQUIT RECIVED")
default:
return nil
}
} }

View File

@ -3,24 +3,6 @@
package stario package stario
import (
"errors"
"os"
"syscall"
)
func signal(sigtype rune) error { func signal(sigtype rune) error {
switch sigtype { return signalErrorForType(sigtype)
case 0x03:
syscall.Kill(os.Getpid(), syscall.SIGINT)
return errors.New("SIGNAL SIGINT RECIVED")
case 0x1a:
syscall.Kill(os.Getpid(), syscall.SIGSTOP)
return errors.New("SIGNAL SIGSTOP RECIVED")
case 0x1c:
syscall.Kill(os.Getpid(), syscall.SIGQUIT)
return errors.New("SIGNAL SIGQUIT RECIVED")
default:
return nil
}
} }

123
starring.go Normal file
View File

@ -0,0 +1,123 @@
package stario
import (
"container/list"
"errors"
"sync"
)
var ErrStarRingInvalidCapacity = errors.New("star ring capacity must be greater than zero")
// StarRing keeps the newest bytes in memory with fixed capacity.
// It never blocks writers: when capacity is exceeded, oldest bytes are dropped.
type StarRing struct {
mu sync.RWMutex
cap int
size int
chunks list.List // each element: []byte
}
func NewStarRing(capacity int) (*StarRing, error) {
if capacity <= 0 {
return nil, ErrStarRingInvalidCapacity
}
return &StarRing{cap: capacity}, nil
}
func (s *StarRing) Capacity() int {
s.mu.RLock()
defer s.mu.RUnlock()
return s.cap
}
func (s *StarRing) Len() int {
s.mu.RLock()
defer s.mu.RUnlock()
return s.size
}
func (s *StarRing) Reset() {
s.mu.Lock()
defer s.mu.Unlock()
s.size = 0
s.chunks.Init()
}
func (s *StarRing) Write(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
payload := append([]byte(nil), p...)
s.mu.Lock()
defer s.mu.Unlock()
s.writeLocked(payload)
return len(p), nil
}
func (s *StarRing) WriteString(text string) (int, error) {
if len(text) == 0 {
return 0, nil
}
return s.Write([]byte(text))
}
func (s *StarRing) Snapshot() []byte {
s.mu.RLock()
defer s.mu.RUnlock()
out := make([]byte, s.size)
pos := 0
for node := s.chunks.Front(); node != nil; node = node.Next() {
chunk, ok := node.Value.([]byte)
if !ok || len(chunk) == 0 {
continue
}
pos += copy(out[pos:], chunk)
}
return out
}
func (s *StarRing) writeLocked(payload []byte) {
if len(payload) >= s.cap {
payload = payload[len(payload)-s.cap:]
s.chunks.Init()
s.chunks.PushBack(payload)
s.size = len(payload)
return
}
s.chunks.PushBack(payload)
s.size += len(payload)
s.trimLocked()
}
func (s *StarRing) trimLocked() {
if s.size <= s.cap {
return
}
overflow := s.size - s.cap
for overflow > 0 {
front := s.chunks.Front()
if front == nil {
s.size = 0
return
}
head, ok := front.Value.([]byte)
if !ok || len(head) == 0 {
s.chunks.Remove(front)
continue
}
if len(head) <= overflow {
s.chunks.Remove(front)
s.size -= len(head)
overflow -= len(head)
continue
}
trimmed := append([]byte(nil), head[overflow:]...)
front.Value = trimmed
s.size -= overflow
overflow = 0
}
}

104
starring_test.go Normal file
View File

@ -0,0 +1,104 @@
package stario
import (
"bytes"
"testing"
)
func TestNewStarRingRejectsInvalidCapacity(t *testing.T) {
ring, err := NewStarRing(0)
if err != ErrStarRingInvalidCapacity {
t.Fatalf("unexpected error: %v", err)
}
if ring != nil {
t.Fatal("expected nil ring when capacity is invalid")
}
}
func TestStarRingWriteAndSnapshot(t *testing.T) {
ring, err := NewStarRing(16)
if err != nil {
t.Fatal(err)
}
if n, err := ring.Write([]byte("abc")); err != nil || n != 3 {
t.Fatalf("write failed: n=%d err=%v", n, err)
}
if n, err := ring.WriteString("def"); err != nil || n != 3 {
t.Fatalf("write string failed: n=%d err=%v", n, err)
}
got := ring.Snapshot()
if !bytes.Equal(got, []byte("abcdef")) {
t.Fatalf("unexpected snapshot: %q", got)
}
if ring.Len() != 6 {
t.Fatalf("unexpected ring len: %d", ring.Len())
}
if ring.Capacity() != 16 {
t.Fatalf("unexpected ring capacity: %d", ring.Capacity())
}
}
func TestStarRingOverflowDropsOldestBytesPrecisely(t *testing.T) {
ring, err := NewStarRing(10)
if err != nil {
t.Fatal(err)
}
_, _ = ring.WriteString("abcde")
_, _ = ring.WriteString("fghij")
_, _ = ring.WriteString("kl")
got := ring.Snapshot()
want := []byte("cdefghijkl")
if !bytes.Equal(got, want) {
t.Fatalf("unexpected snapshot after overflow: got %q want %q", got, want)
}
if ring.Len() != len(want) {
t.Fatalf("unexpected ring len after overflow: %d", ring.Len())
}
}
func TestStarRingLargeWriteKeepsTail(t *testing.T) {
ring, err := NewStarRing(5)
if err != nil {
t.Fatal(err)
}
_, _ = ring.WriteString("123")
_, _ = ring.WriteString("456789")
got := ring.Snapshot()
want := []byte("56789")
if !bytes.Equal(got, want) {
t.Fatalf("unexpected snapshot after large write: got %q want %q", got, want)
}
}
func TestStarRingSnapshotIsCopy(t *testing.T) {
ring, err := NewStarRing(8)
if err != nil {
t.Fatal(err)
}
_, _ = ring.WriteString("hello")
first := ring.Snapshot()
first[0] = 'H'
second := ring.Snapshot()
if !bytes.Equal(second, []byte("hello")) {
t.Fatalf("snapshot should not expose internal storage: %q", second)
}
}
func TestStarRingReset(t *testing.T) {
ring, err := NewStarRing(8)
if err != nil {
t.Fatal(err)
}
_, _ = ring.WriteString("hello")
ring.Reset()
if ring.Len() != 0 {
t.Fatalf("expected len=0 after reset, got %d", ring.Len())
}
if got := ring.Snapshot(); len(got) != 0 {
t.Fatalf("expected empty snapshot after reset, got %q", got)
}
}

17
sync.go
View File

@ -12,6 +12,10 @@ const (
waitGroupAddModeLoose waitGroupAddModeLoose
) )
// WaitGroup is a concurrency-limited sync.WaitGroup variant.
//
// A zero or negative limit means unlimited concurrency. WaitGroup must not be
// copied after first use.
type WaitGroup struct { type WaitGroup struct {
wg sync.WaitGroup wg sync.WaitGroup
mu sync.Mutex mu sync.Mutex
@ -22,6 +26,7 @@ type WaitGroup struct {
addMode waitGroupAddMode addMode waitGroupAddMode
} }
// NewWaitGroup creates a WaitGroup with the provided concurrency limit.
func NewWaitGroup(maxCount int) WaitGroup { func NewWaitGroup(maxCount int) WaitGroup {
if maxCount < 0 { if maxCount < 0 {
panic("stario: negative max wait count") panic("stario: negative max wait count")
@ -38,6 +43,10 @@ func (w *WaitGroup) init() {
}) })
} }
// Add adjusts the running task count.
//
// Positive deltas may block when the concurrency limit is reached. Negative
// deltas release running slots.
func (w *WaitGroup) Add(delta int) { func (w *WaitGroup) Add(delta int) {
w.init() w.init()
if delta == 0 { if delta == 0 {
@ -87,10 +96,12 @@ func (w *WaitGroup) release(delta int) {
w.cond.Broadcast() w.cond.Broadcast()
} }
// Done releases one running task slot.
func (w *WaitGroup) Done() { func (w *WaitGroup) Done() {
w.Add(-1) w.Add(-1)
} }
// Go runs fn in a goroutine while accounting for the concurrency limit.
func (w *WaitGroup) Go(fn func()) { func (w *WaitGroup) Go(fn func()) {
w.Add(1) w.Add(1)
go func() { go func() {
@ -99,11 +110,13 @@ func (w *WaitGroup) Go(fn func()) {
}() }()
} }
// Wait blocks until all added work has completed.
func (w *WaitGroup) Wait() { func (w *WaitGroup) Wait() {
w.init() w.init()
w.wg.Wait() w.wg.Wait()
} }
// GetMaxWaitNum returns the current concurrency limit.
func (w *WaitGroup) GetMaxWaitNum() int { func (w *WaitGroup) GetMaxWaitNum() int {
w.init() w.init()
w.mu.Lock() w.mu.Lock()
@ -111,6 +124,7 @@ func (w *WaitGroup) GetMaxWaitNum() int {
return w.maxCount return w.maxCount
} }
// SetMaxWaitNum updates the concurrency limit.
func (w *WaitGroup) SetMaxWaitNum(num int) { func (w *WaitGroup) SetMaxWaitNum(num int) {
if num < 0 { if num < 0 {
panic("stario: negative max wait count") panic("stario: negative max wait count")
@ -122,6 +136,8 @@ func (w *WaitGroup) SetMaxWaitNum(num int) {
w.cond.Broadcast() w.cond.Broadcast()
} }
// SetStrictAddMode controls whether Add(n>1) panics or auto-expands the limit
// when the requested batch exceeds the current capacity.
func (w *WaitGroup) SetStrictAddMode(strict bool) { func (w *WaitGroup) SetStrictAddMode(strict bool) {
w.init() w.init()
w.mu.Lock() w.mu.Lock()
@ -134,6 +150,7 @@ func (w *WaitGroup) SetStrictAddMode(strict bool) {
w.cond.Broadcast() w.cond.Broadcast()
} }
// StrictAddMode reports whether strict batch-add behavior is enabled.
func (w *WaitGroup) StrictAddMode() bool { func (w *WaitGroup) StrictAddMode() bool {
w.init() w.init()
w.mu.Lock() w.mu.Lock()

73
terminal.go Normal file
View File

@ -0,0 +1,73 @@
package stario
import (
"context"
"errors"
"os"
"golang.org/x/term"
)
// ErrTerminalNotTTY reports that terminal-only input was requested from a
// non-terminal stdin.
var ErrTerminalNotTTY = errors.New("terminal input requires a tty")
var terminalCheckFunc = term.IsTerminal
// IsTerminal reports whether os.Stdin is attached to a terminal.
func IsTerminal() bool {
return IsTerminalFile(os.Stdin)
}
// IsTerminalFD reports whether fd is attached to a terminal.
func IsTerminalFD(fd int) bool {
return terminalCheckFunc(fd)
}
// IsTerminalFile reports whether file is attached to a terminal.
func IsTerminalFile(file *os.File) bool {
if file == nil {
return false
}
return IsTerminalFD(int(file.Fd()))
}
// ReadPasswordContext reads one password-style line from the terminal.
//
// It returns ErrTerminalNotTTY when stdin is not a terminal. If ctx is canceled
// while waiting for input, this function returns ctx.Err() without waiting for
// the underlying terminal read to finish.
func ReadPasswordContext(ctx context.Context, hint string, defaultVal string) (string, error) {
return ReadPasswordContextWithMask(ctx, hint, defaultVal, "")
}
// ReadPasswordContextWithMask is like ReadPasswordContext but echoes the given
// mask string while the user types.
func ReadPasswordContextWithMask(ctx context.Context, hint string, defaultVal string, mask string) (string, error) {
if ctx == nil {
ctx = context.Background()
}
if err := ctx.Err(); err != nil {
return "", err
}
if !IsTerminal() {
return "", ErrTerminalNotTTY
}
session, err := rawTerminalSessionFactory(hint, true)
if err != nil {
return "", err
}
resultCh := make(chan InputMsg, 1)
go func() {
defer session.Close()
resultCh <- rawLineInputSession(session, defaultVal, mask, rawInputSignalReturnError)
}()
select {
case result := <-resultCh:
return result.String()
case <-ctx.Done():
session.Abort()
<-resultCh
return "", ctx.Err()
}
}

77
terminal_test.go Normal file
View File

@ -0,0 +1,77 @@
package stario
import (
"bufio"
"context"
"errors"
"io"
"strings"
"testing"
"time"
)
func installTerminalCheckStub(t *testing.T, ok bool) {
t.Helper()
prev := terminalCheckFunc
terminalCheckFunc = func(fd int) bool {
return ok
}
t.Cleanup(func() {
terminalCheckFunc = prev
})
}
func TestIsTerminalFDUsesStub(t *testing.T) {
installTerminalCheckStub(t, true)
if !IsTerminalFD(1) {
t.Fatal("expected terminal check to return true")
}
}
func TestReadPasswordContextRejectsNonTTY(t *testing.T) {
installTerminalCheckStub(t, false)
_, err := ReadPasswordContext(context.Background(), "pwd: ", "")
if !errors.Is(err, ErrTerminalNotTTY) {
t.Fatalf("expected ErrTerminalNotTTY, got %v", err)
}
}
func TestReadPasswordContextReadsValue(t *testing.T) {
installTerminalCheckStub(t, true)
installRawInputStub(t, "secret\n", nil)
got, err := ReadPasswordContext(context.Background(), "pwd: ", "")
if err != nil {
t.Fatalf("ReadPasswordContext returned error: %v", err)
}
if got != "secret" {
t.Fatalf("unexpected password: %q", got)
}
}
func TestReadPasswordContextCanceledWhileWaiting(t *testing.T) {
installTerminalCheckStub(t, true)
prevFactory := rawTerminalSessionFactory
pipeReader, pipeWriter := io.Pipe()
rawTerminalSessionFactory = func(hint string, printNewline bool) (*rawTerminalSession, error) {
return &rawTerminalSession{
input: pipeReader,
reader: bufio.NewReader(pipeReader),
redrawHint: strings.TrimSpace(hint),
}, nil
}
t.Cleanup(func() {
rawTerminalSessionFactory = prevFactory
_ = pipeWriter.Close()
_ = pipeReader.Close()
})
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
_, err := ReadPasswordContext(ctx, "pwd: ", "")
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("expected deadline exceeded, got %v", err)
}
}

10
tty_other.go Normal file
View File

@ -0,0 +1,10 @@
//go:build !windows
// +build !windows
package stario
import "os"
func openRawTerminalInput() (*os.File, error) {
return os.OpenFile("/dev/tty", os.O_RDWR, 0)
}

10
tty_windows.go Normal file
View File

@ -0,0 +1,10 @@
//go:build windows
// +build windows
package stario
import "os"
func openRawTerminalInput() (*os.File, error) {
return os.OpenFile("CONIN$", os.O_RDWR, 0)
}