stario: 提升 Go 1.20 基线与交互/队列稳定性
- 提升 go.mod 基线到 Go 1.20,并补齐对应测试 - 修正 Passwd / PasswdResponseSignal 语义,Ctrl+C 默认退出当前流程 - 优化 raw terminal redraw、Restore 与 StopUntil 的边界行为 - 新增 StarPipe、FrameReader/FrameWriter、ReadFullContext/WriteFullContext/CopyContext、IsTerminal/ReadPasswordContext - 收口 StarQueue / StarBuffer 语义,删除 EndWrite,统一 Close / Abort 行为 - 补齐 signal、timeout、queue、terminal、pipe、buffer 的回归测试与 race 覆盖
This commit is contained in:
parent
3add9183b3
commit
c8facb5a03
43
circle.go
43
circle.go
@ -10,6 +10,10 @@ var ErrStarBufferInvalidCapacity = errors.New("star buffer capacity must be grea
|
||||
var ErrStarBufferClosed = errors.New("star buffer closed")
|
||||
var ErrStarBufferWriteClosed = errors.New("star buffer write closed")
|
||||
|
||||
// StarBuffer is a blocking ring buffer that implements stream-style reads and writes.
|
||||
//
|
||||
// Close marks the write side finished after all payload bytes are sent.
|
||||
// Abort aborts both sides immediately but still allows buffered bytes to be drained.
|
||||
type StarBuffer struct {
|
||||
datas []byte
|
||||
pStart uint64
|
||||
@ -36,18 +40,21 @@ func NewStarBuffer(cap uint64) (*StarBuffer, error) {
|
||||
return rtnBuffer, nil
|
||||
}
|
||||
|
||||
// Free returns the remaining writable capacity.
|
||||
func (star *StarBuffer) Free() uint64 {
|
||||
star.mu.Lock()
|
||||
defer star.mu.Unlock()
|
||||
return star.cap - star.size
|
||||
}
|
||||
|
||||
// Cap returns the fixed buffer capacity.
|
||||
func (star *StarBuffer) Cap() uint64 {
|
||||
star.mu.Lock()
|
||||
defer star.mu.Unlock()
|
||||
return star.cap
|
||||
}
|
||||
|
||||
// Len returns the currently buffered byte count.
|
||||
func (star *StarBuffer) Len() uint64 {
|
||||
star.mu.Lock()
|
||||
defer star.mu.Unlock()
|
||||
@ -89,19 +96,21 @@ func (star *StarBuffer) putByte(data byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (star *StarBuffer) EndWrite() error {
|
||||
// Close closes only the write side and satisfies the usual io.Closer-style
|
||||
// "producer finished" semantics.
|
||||
//
|
||||
// Buffered bytes remain readable until drained; afterwards reads return io.EOF.
|
||||
func (star *StarBuffer) Close() error {
|
||||
star.mu.Lock()
|
||||
defer star.mu.Unlock()
|
||||
if star.isClose {
|
||||
return ErrStarBufferClosed
|
||||
}
|
||||
star.isWriteEnd = true
|
||||
star.notEmpty.Broadcast()
|
||||
star.notFull.Broadcast()
|
||||
return nil
|
||||
return star.closeWriteLocked()
|
||||
}
|
||||
|
||||
func (star *StarBuffer) Close() error {
|
||||
// Abort aborts the buffer and wakes blocked readers/writers immediately.
|
||||
//
|
||||
// Buffered bytes remain readable until drained; subsequent writes fail with
|
||||
// ErrStarBufferClosed.
|
||||
func (star *StarBuffer) Abort() error {
|
||||
star.mu.Lock()
|
||||
defer star.mu.Unlock()
|
||||
star.isClose = true
|
||||
@ -111,10 +120,17 @@ func (star *StarBuffer) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (star *StarBuffer) Read(buf []byte) (int, error) {
|
||||
if buf == nil {
|
||||
return 0, errors.New("buffer is nil")
|
||||
func (star *StarBuffer) closeWriteLocked() error {
|
||||
if star.isClose {
|
||||
return ErrStarBufferClosed
|
||||
}
|
||||
star.isWriteEnd = true
|
||||
star.notEmpty.Broadcast()
|
||||
star.notFull.Broadcast()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (star *StarBuffer) Read(buf []byte) (int, error) {
|
||||
if len(buf) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
@ -140,9 +156,6 @@ func (star *StarBuffer) Read(buf []byte) (int, error) {
|
||||
}
|
||||
|
||||
func (star *StarBuffer) Write(bts []byte) (int, error) {
|
||||
if bts == nil {
|
||||
return 0, star.EndWrite()
|
||||
}
|
||||
if len(bts) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
44
circle_benchmark_test.go
Normal file
44
circle_benchmark_test.go
Normal file
@ -0,0 +1,44 @@
|
||||
package stario
|
||||
|
||||
import (
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func BenchmarkStarBufferByteRoundTrip(b *testing.B) {
|
||||
buf, err := NewStarBuffer(4096)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(1)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := buf.putByte('a'); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if _, err := buf.getByte(); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStarBufferChunkRoundTrip(b *testing.B) {
|
||||
buf, err := NewStarBuffer(8192)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
payload := []byte("hello world b612 hello world b612 b612 b612 b612 b612 b612")
|
||||
scratch := make([]byte, len(payload))
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(payload)))
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := buf.Write(payload); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if _, err := io.ReadFull(buf, scratch); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
105
circle_test.go
105
circle_test.go
@ -2,11 +2,8 @@ package stario
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewStarBufferRejectsZeroCapacity(t *testing.T) {
|
||||
@ -19,7 +16,7 @@ func TestNewStarBufferRejectsZeroCapacity(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStarBufferEndWriteDrainsThenEOF(t *testing.T) {
|
||||
func TestStarBufferCloseDrainsThenEOF(t *testing.T) {
|
||||
buf, err := NewStarBuffer(4)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@ -27,7 +24,7 @@ func TestStarBufferEndWriteDrainsThenEOF(t *testing.T) {
|
||||
if _, err := buf.Write([]byte("abcd")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := buf.EndWrite(); err != nil {
|
||||
if err := buf.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@ -46,11 +43,11 @@ func TestStarBufferEndWriteDrainsThenEOF(t *testing.T) {
|
||||
}
|
||||
|
||||
if _, err := buf.Write([]byte("x")); err != ErrStarBufferWriteClosed {
|
||||
t.Fatalf("unexpected write error after EndWrite: %v", err)
|
||||
t.Fatalf("unexpected write error after Close: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStarBufferCloseAllowsDrain(t *testing.T) {
|
||||
func TestStarBufferAbortAllowsDrain(t *testing.T) {
|
||||
buf, err := NewStarBuffer(4)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@ -58,7 +55,7 @@ func TestStarBufferCloseAllowsDrain(t *testing.T) {
|
||||
if _, err := buf.Write([]byte("ab")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := buf.Close(); err != nil {
|
||||
if err := buf.Abort(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@ -68,102 +65,38 @@ func TestStarBufferCloseAllowsDrain(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != 2 || !bytes.Equal(got[:n], []byte("ab")) {
|
||||
t.Fatalf("unexpected payload after close: n=%d data=%q", n, got[:n])
|
||||
t.Fatalf("unexpected payload after abort: n=%d data=%q", n, got[:n])
|
||||
}
|
||||
|
||||
n, err = buf.Read(got)
|
||||
if n != 0 || err != io.EOF {
|
||||
t.Fatalf("expected EOF after draining closed buffer, got n=%d err=%v", n, err)
|
||||
t.Fatalf("expected EOF after draining aborted buffer, got n=%d err=%v", n, err)
|
||||
}
|
||||
|
||||
if _, err := buf.Write([]byte("x")); err != ErrStarBufferClosed {
|
||||
t.Fatalf("unexpected write error after Close: %v", err)
|
||||
t.Fatalf("unexpected write error after Abort: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Circle(t *testing.T) {
|
||||
buf, err := NewStarBuffer(2048)
|
||||
func TestStarBufferNilReadIsNoOp(t *testing.T) {
|
||||
buf, err := NewStarBuffer(4)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
//fmt.Println("write start")
|
||||
buf.Write([]byte("中华人民共和国\n"))
|
||||
//fmt.Println("write success")
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
}
|
||||
}()
|
||||
cpp := ""
|
||||
go func() {
|
||||
time.Sleep(time.Second * 3)
|
||||
for {
|
||||
cache := make([]byte, 64)
|
||||
ints, err := buf.Read(cache)
|
||||
if err != nil {
|
||||
fmt.Println("read error", err)
|
||||
return
|
||||
}
|
||||
if ints != 0 {
|
||||
cpp += string(cache[:ints])
|
||||
}
|
||||
}
|
||||
}()
|
||||
time.Sleep(time.Second * 13)
|
||||
fmt.Println(cpp)
|
||||
if n, err := buf.Read(nil); n != 0 || err != nil {
|
||||
t.Fatalf("expected nil read to be a no-op, got n=%d err=%v", n, err)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Circle_Speed(t *testing.T) {
|
||||
buf, err := NewStarBuffer(1048976)
|
||||
func TestStarBufferNilWriteIsNoOp(t *testing.T) {
|
||||
buf, err := NewStarBuffer(4)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
count := uint64(0)
|
||||
for i := 1; i <= 10; i++ {
|
||||
go func() {
|
||||
for {
|
||||
buf.putByte('a')
|
||||
}
|
||||
}()
|
||||
if n, err := buf.Write(nil); n != 0 || err != nil {
|
||||
t.Fatalf("expected nil write to be a no-op, got n=%d err=%v", n, err)
|
||||
}
|
||||
for i := 1; i <= 10; i++ {
|
||||
go func() {
|
||||
for {
|
||||
_, err := buf.getByte()
|
||||
if err == nil {
|
||||
atomic.AddUint64(&count, 1)
|
||||
}
|
||||
}
|
||||
}()
|
||||
if _, err := buf.Write([]byte("ab")); err != nil {
|
||||
t.Fatalf("nil write must not end the write side, got %v", err)
|
||||
}
|
||||
time.Sleep(time.Second * 10)
|
||||
fmt.Println(count)
|
||||
}
|
||||
|
||||
func Test_Circle_Speed2(t *testing.T) {
|
||||
buf, err := NewStarBuffer(8192)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
count := uint64(0)
|
||||
for i := 1; i <= 10; i++ {
|
||||
go func() {
|
||||
for {
|
||||
buf.Write([]byte("hello world b612 hello world b612 b612 b612 b612 b612 b612"))
|
||||
}
|
||||
}()
|
||||
}
|
||||
for i := 1; i <= 10; i++ {
|
||||
go func() {
|
||||
for {
|
||||
mybuf := make([]byte, 1024)
|
||||
j, err := buf.Read(mybuf)
|
||||
if err == nil {
|
||||
atomic.AddUint64(&count, uint64(j))
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
time.Sleep(time.Second * 10)
|
||||
fmt.Println(float64(count) / 10 / 1024 / 1024)
|
||||
}
|
||||
|
||||
126
fn.go
126
fn.go
@ -1,55 +1,111 @@
|
||||
package stario
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ERR_TIMEOUT is the legacy timeout sentinel used by WaitUntilTimeout*.
|
||||
var ERR_TIMEOUT = errors.New("TIME OUT")
|
||||
|
||||
func WaitUntilTimeout(tm time.Duration, fn func(chan struct{}) error) error {
|
||||
var err error
|
||||
finished := make(chan struct{})
|
||||
imout := make(chan struct{})
|
||||
// WaitUntilContext runs fn and returns either its result or the context error,
|
||||
// whichever happens first.
|
||||
func WaitUntilContext(ctx context.Context, fn func(context.Context) error) error {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
finished := make(chan error, 1)
|
||||
go func() {
|
||||
err = fn(imout)
|
||||
finished <- struct{}{}
|
||||
finished <- fn(ctx)
|
||||
}()
|
||||
select {
|
||||
case <-finished:
|
||||
case err := <-finished:
|
||||
return err
|
||||
case <-time.After(tm):
|
||||
close(imout)
|
||||
return ERR_TIMEOUT
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func WaitUntilFinished(fn func() error) <-chan error {
|
||||
finished := make(chan error)
|
||||
// WaitUntilContextFinished is the asynchronous form of WaitUntilContext.
|
||||
func WaitUntilContextFinished(ctx context.Context, fn func(context.Context) error) <-chan error {
|
||||
result := make(chan error, 1)
|
||||
go func() {
|
||||
err := fn()
|
||||
finished <- err
|
||||
}()
|
||||
return finished
|
||||
}
|
||||
|
||||
func WaitUntilTimeoutFinished(tm time.Duration, fn func(chan struct{}) error) <-chan error {
|
||||
var err error
|
||||
finished := make(chan struct{})
|
||||
result := make(chan error)
|
||||
imout := make(chan struct{})
|
||||
go func() {
|
||||
err = fn(imout)
|
||||
finished <- struct{}{}
|
||||
}()
|
||||
go func() {
|
||||
select {
|
||||
case <-finished:
|
||||
result <- err
|
||||
case <-time.After(tm):
|
||||
close(imout)
|
||||
result <- ERR_TIMEOUT
|
||||
}
|
||||
result <- WaitUntilContext(ctx, fn)
|
||||
close(result)
|
||||
}()
|
||||
return result
|
||||
}
|
||||
|
||||
// WaitUntilContextDone adapts a done-channel worker to a context-based wait.
|
||||
func WaitUntilContextDone(ctx context.Context, fn func(<-chan struct{}) error) error {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return WaitUntilContext(ctx, func(context.Context) error {
|
||||
return fn(ctx.Done())
|
||||
})
|
||||
}
|
||||
|
||||
// WaitUntilContextDoneFinished is the asynchronous form of WaitUntilContextDone.
|
||||
func WaitUntilContextDoneFinished(ctx context.Context, fn func(<-chan struct{}) error) <-chan error {
|
||||
return WaitUntilContextFinished(ctx, func(ctx context.Context) error {
|
||||
return fn(ctx.Done())
|
||||
})
|
||||
}
|
||||
|
||||
// WaitUntilTimeout is a legacy timeout helper kept for compatibility.
|
||||
//
|
||||
// The provided stop channel must be treated as receive-only by callers. New
|
||||
// code should prefer WaitUntilContext or WaitUntilContextDone.
|
||||
func WaitUntilTimeout(tm time.Duration, fn func(chan struct{}) error) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), tm)
|
||||
defer cancel()
|
||||
err := WaitUntilContextDone(ctx, func(done <-chan struct{}) error {
|
||||
return fn(bridgeDoneChan(done))
|
||||
})
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return ERR_TIMEOUT
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// WaitUntilFinished runs fn asynchronously and returns its eventual result.
|
||||
func WaitUntilFinished(fn func() error) <-chan error {
|
||||
return WaitUntilContextFinished(context.Background(), func(ctx context.Context) error {
|
||||
return fn()
|
||||
})
|
||||
}
|
||||
|
||||
// WaitUntilTimeoutFinished is the asynchronous form of WaitUntilTimeout.
|
||||
//
|
||||
// The provided stop channel must be treated as receive-only by callers. New
|
||||
// code should prefer WaitUntilContextFinished or WaitUntilContextDoneFinished.
|
||||
func WaitUntilTimeoutFinished(tm time.Duration, fn func(chan struct{}) error) <-chan error {
|
||||
result := make(chan error, 1)
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), tm)
|
||||
defer cancel()
|
||||
err := WaitUntilContextDone(ctx, func(done <-chan struct{}) error {
|
||||
return fn(bridgeDoneChan(done))
|
||||
})
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
err = ERR_TIMEOUT
|
||||
}
|
||||
result <- err
|
||||
close(result)
|
||||
}()
|
||||
return result
|
||||
}
|
||||
|
||||
func bridgeDoneChan(done <-chan struct{}) chan struct{} {
|
||||
stop := make(chan struct{})
|
||||
if done == nil {
|
||||
return stop
|
||||
}
|
||||
go func() {
|
||||
<-done
|
||||
close(stop)
|
||||
}()
|
||||
return stop
|
||||
}
|
||||
|
||||
174
fn_test.go
Normal file
174
fn_test.go
Normal file
@ -0,0 +1,174 @@
|
||||
package stario
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestWaitUntilContextReturnsWorkerError(t *testing.T) {
|
||||
want := errors.New("worker failed")
|
||||
err := WaitUntilContext(context.Background(), func(ctx context.Context) error {
|
||||
return want
|
||||
})
|
||||
if !errors.Is(err, want) {
|
||||
t.Fatalf("unexpected error: got %v want %v", err, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWaitUntilContextReturnsDeadlineExceeded(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
defer cancel()
|
||||
workerReturned := make(chan struct{})
|
||||
err := WaitUntilContext(ctx, func(ctx context.Context) error {
|
||||
<-ctx.Done()
|
||||
close(workerReturned)
|
||||
return nil
|
||||
})
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Fatalf("unexpected context error: %v", err)
|
||||
}
|
||||
select {
|
||||
case <-workerReturned:
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("worker did not return after context deadline")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWaitUntilContextFinishedReturnsCanceled(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
result := WaitUntilContextFinished(ctx, func(ctx context.Context) error {
|
||||
<-ctx.Done()
|
||||
return nil
|
||||
})
|
||||
err, ok := <-result
|
||||
if !ok {
|
||||
t.Fatal("result channel closed without value")
|
||||
}
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("unexpected context error: %v", err)
|
||||
}
|
||||
if _, ok := <-result; ok {
|
||||
t.Fatal("result channel should be closed after delivering the outcome")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWaitUntilContextDoneBridgesDoneChannel(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
doneClosed := make(chan struct{})
|
||||
go func() {
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
err := WaitUntilContextDone(ctx, func(done <-chan struct{}) error {
|
||||
<-done
|
||||
close(doneClosed)
|
||||
return nil
|
||||
})
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("unexpected context error: %v", err)
|
||||
}
|
||||
select {
|
||||
case <-doneClosed:
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("done channel was not closed when context canceled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWaitUntilContextDoneFinishedReturnsCanceled(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
doneClosed := make(chan struct{})
|
||||
go func() {
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
result := WaitUntilContextDoneFinished(ctx, func(done <-chan struct{}) error {
|
||||
<-done
|
||||
close(doneClosed)
|
||||
return nil
|
||||
})
|
||||
err, ok := <-result
|
||||
if !ok {
|
||||
t.Fatal("result channel closed without value")
|
||||
}
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("unexpected context error: %v", err)
|
||||
}
|
||||
select {
|
||||
case <-doneClosed:
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("done channel was not closed when context canceled")
|
||||
}
|
||||
if _, ok := <-result; ok {
|
||||
t.Fatal("result channel should be closed after delivering the outcome")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWaitUntilTimeoutReturnsWorkerError(t *testing.T) {
|
||||
want := errors.New("worker failed")
|
||||
err := WaitUntilTimeout(time.Second, func(stop chan struct{}) error {
|
||||
return want
|
||||
})
|
||||
if !errors.Is(err, want) {
|
||||
t.Fatalf("unexpected error: got %v want %v", err, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWaitUntilTimeoutTimesOutWithoutBlockingWorkerReturn(t *testing.T) {
|
||||
workerReturned := make(chan struct{})
|
||||
err := WaitUntilTimeout(20*time.Millisecond, func(stop chan struct{}) error {
|
||||
<-stop
|
||||
close(workerReturned)
|
||||
return nil
|
||||
})
|
||||
if !errors.Is(err, ERR_TIMEOUT) {
|
||||
t.Fatalf("unexpected timeout error: %v", err)
|
||||
}
|
||||
select {
|
||||
case <-workerReturned:
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("worker did not return after timeout signal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWaitUntilFinishedReturnsWorkerError(t *testing.T) {
|
||||
want := errors.New("worker failed")
|
||||
err := <-WaitUntilFinished(func() error {
|
||||
return want
|
||||
})
|
||||
if !errors.Is(err, want) {
|
||||
t.Fatalf("unexpected error: got %v want %v", err, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWaitUntilTimeoutFinishedReturnsTimeout(t *testing.T) {
|
||||
result := WaitUntilTimeoutFinished(20*time.Millisecond, func(stop chan struct{}) error {
|
||||
<-stop
|
||||
return nil
|
||||
})
|
||||
err, ok := <-result
|
||||
if !ok {
|
||||
t.Fatal("result channel closed without value")
|
||||
}
|
||||
if !errors.Is(err, ERR_TIMEOUT) {
|
||||
t.Fatalf("unexpected timeout error: %v", err)
|
||||
}
|
||||
if _, ok := <-result; ok {
|
||||
t.Fatal("result channel should be closed after delivering the outcome")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWaitUntilTimeoutFinishedReturnsWorkerError(t *testing.T) {
|
||||
want := errors.New("worker failed")
|
||||
err, ok := <-WaitUntilTimeoutFinished(time.Second, func(stop chan struct{}) error {
|
||||
return want
|
||||
})
|
||||
if !ok {
|
||||
t.Fatal("result channel closed without value")
|
||||
}
|
||||
if !errors.Is(err, want) {
|
||||
t.Fatalf("unexpected error: got %v want %v", err, want)
|
||||
}
|
||||
}
|
||||
191
frameio.go
Normal file
191
frameio.go
Normal file
@ -0,0 +1,191 @@
|
||||
package stario
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
// DefaultFrameReaderBufferSize is the default transport read chunk size used by
|
||||
// FrameReader.
|
||||
const DefaultFrameReaderBufferSize = 32 * 1024
|
||||
|
||||
type frameReaderConnKey struct{}
|
||||
|
||||
// FrameWriter adapts StarQueue framing helpers to an io.Writer.
|
||||
type FrameWriter struct {
|
||||
writer io.Writer
|
||||
queue *StarQueue
|
||||
}
|
||||
|
||||
// NewFrameWriter creates a framing writer backed by queue. When queue is nil, a
|
||||
// default StarQueue is created.
|
||||
func NewFrameWriter(writer io.Writer, queue *StarQueue) *FrameWriter {
|
||||
if queue == nil {
|
||||
queue = NewQueue()
|
||||
}
|
||||
return &FrameWriter{
|
||||
writer: writer,
|
||||
queue: queue,
|
||||
}
|
||||
}
|
||||
|
||||
// WriteFrame writes one framed payload.
|
||||
func (writer *FrameWriter) WriteFrame(payload []byte) error {
|
||||
if writer == nil || writer.writer == nil || writer.queue == nil {
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
return writer.queue.WriteFrame(writer.writer, payload)
|
||||
}
|
||||
|
||||
// WriteFrameBuffers writes one framed payload using net.Buffers when possible.
|
||||
func (writer *FrameWriter) WriteFrameBuffers(payload []byte) error {
|
||||
if writer == nil || writer.writer == nil || writer.queue == nil {
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
return writer.queue.WriteFrameBuffers(writer.writer, payload)
|
||||
}
|
||||
|
||||
// WriteFramesBuffers writes multiple framed payloads in one batch when
|
||||
// possible.
|
||||
func (writer *FrameWriter) WriteFramesBuffers(payloads ...[]byte) error {
|
||||
if writer == nil || writer.writer == nil || writer.queue == nil {
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
return writer.queue.WriteFramesBuffers(writer.writer, payloads...)
|
||||
}
|
||||
|
||||
// FrameReader adapts StarQueue parsing helpers to an io.Reader.
|
||||
type FrameReader struct {
|
||||
reader io.Reader
|
||||
queue *StarQueue
|
||||
connKey interface{}
|
||||
readSize int
|
||||
pending [][]byte
|
||||
pendingErr error
|
||||
}
|
||||
|
||||
// NewFrameReader creates a framing reader backed by queue. When queue is nil, a
|
||||
// default StarQueue is created.
|
||||
func NewFrameReader(reader io.Reader, queue *StarQueue) *FrameReader {
|
||||
if queue == nil {
|
||||
queue = NewQueue()
|
||||
}
|
||||
return &FrameReader{
|
||||
reader: reader,
|
||||
queue: queue,
|
||||
connKey: &frameReaderConnKey{},
|
||||
readSize: DefaultFrameReaderBufferSize,
|
||||
}
|
||||
}
|
||||
|
||||
// SetReadBufferSize updates the underlying transport read chunk size.
|
||||
func (reader *FrameReader) SetReadBufferSize(size int) {
|
||||
if reader == nil || size <= 0 {
|
||||
return
|
||||
}
|
||||
reader.readSize = size
|
||||
}
|
||||
|
||||
// SetConnKey overrides the internal queue connection key.
|
||||
func (reader *FrameReader) SetConnKey(conn interface{}) error {
|
||||
if reader == nil {
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
if conn == nil {
|
||||
reader.connKey = &frameReaderConnKey{}
|
||||
return nil
|
||||
}
|
||||
if err := validateConnKey(conn); err != nil {
|
||||
return err
|
||||
}
|
||||
reader.connKey = conn
|
||||
return nil
|
||||
}
|
||||
|
||||
// Next returns the next framed payload.
|
||||
func (reader *FrameReader) Next() ([]byte, error) {
|
||||
if reader == nil || reader.reader == nil || reader.queue == nil {
|
||||
return nil, io.ErrClosedPipe
|
||||
}
|
||||
if len(reader.pending) > 0 {
|
||||
return reader.popPending(), nil
|
||||
}
|
||||
if reader.pendingErr != nil {
|
||||
err := reader.pendingErr
|
||||
reader.pendingErr = nil
|
||||
return nil, err
|
||||
}
|
||||
if reader.readSize <= 0 {
|
||||
reader.readSize = DefaultFrameReaderBufferSize
|
||||
}
|
||||
buf := make([]byte, reader.readSize)
|
||||
for {
|
||||
n, readErr := reader.reader.Read(buf)
|
||||
if n > 0 {
|
||||
parseErr := reader.queue.ParseMessageOwned(buf[:n], reader.connKey, func(msg MsgQueue) error {
|
||||
reader.pending = append(reader.pending, msg.Msg)
|
||||
return nil
|
||||
})
|
||||
err := reader.normalizeReadErr(parseErr, readErr)
|
||||
if len(reader.pending) > 0 {
|
||||
reader.pendingErr = err
|
||||
return reader.popPending(), nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if readErr != nil {
|
||||
return nil, reader.normalizeReadErr(nil, readErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (reader *FrameReader) popPending() []byte {
|
||||
next := reader.pending[0]
|
||||
reader.pending = reader.pending[1:]
|
||||
return next
|
||||
}
|
||||
|
||||
func joinFrameReaderError(parseErr error, readErr error) error {
|
||||
switch {
|
||||
case parseErr == nil:
|
||||
return readErr
|
||||
case readErr == nil:
|
||||
return parseErr
|
||||
default:
|
||||
return errors.Join(parseErr, readErr)
|
||||
}
|
||||
}
|
||||
|
||||
func (reader *FrameReader) normalizeReadErr(parseErr error, readErr error) error {
|
||||
if errors.Is(readErr, io.EOF) && reader.hasBufferedState() {
|
||||
reader.clearBufferedState()
|
||||
readErr = io.ErrUnexpectedEOF
|
||||
}
|
||||
return joinFrameReaderError(parseErr, readErr)
|
||||
}
|
||||
|
||||
func (reader *FrameReader) hasBufferedState() bool {
|
||||
if reader == nil || reader.queue == nil {
|
||||
return false
|
||||
}
|
||||
stateAny, ok := reader.queue.states.Load(reader.connKey)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
state, ok := stateAny.(*queConnState)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
state.mu.Lock()
|
||||
defer state.mu.Unlock()
|
||||
return len(state.buf) > 0
|
||||
}
|
||||
|
||||
func (reader *FrameReader) clearBufferedState() {
|
||||
if reader == nil || reader.queue == nil {
|
||||
return
|
||||
}
|
||||
reader.queue.states.Delete(reader.connKey)
|
||||
}
|
||||
105
frameio_test.go
Normal file
105
frameio_test.go
Normal file
@ -0,0 +1,105 @@
|
||||
package stario
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type chunkedReader struct {
|
||||
data []byte
|
||||
max int
|
||||
}
|
||||
|
||||
func (reader *chunkedReader) Read(p []byte) (int, error) {
|
||||
if len(reader.data) == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
if reader.max > 0 && len(p) > reader.max {
|
||||
p = p[:reader.max]
|
||||
}
|
||||
n := copy(p, reader.data)
|
||||
reader.data = reader.data[n:]
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func TestFrameWriterReaderRoundTrip(t *testing.T) {
|
||||
var wire bytes.Buffer
|
||||
writer := NewFrameWriter(&wire, nil)
|
||||
if err := writer.WriteFrame([]byte("one")); err != nil {
|
||||
t.Fatalf("WriteFrame failed: %v", err)
|
||||
}
|
||||
if err := writer.WriteFramesBuffers([][]byte{[]byte("two"), []byte("three")}...); err != nil {
|
||||
t.Fatalf("WriteFramesBuffers failed: %v", err)
|
||||
}
|
||||
|
||||
reader := NewFrameReader(bytes.NewReader(wire.Bytes()), nil)
|
||||
cases := []string{"one", "two", "three"}
|
||||
for _, want := range cases {
|
||||
got, err := reader.Next()
|
||||
if err != nil {
|
||||
t.Fatalf("Next returned error: %v", err)
|
||||
}
|
||||
if string(got) != want {
|
||||
t.Fatalf("unexpected frame: got %q want %q", got, want)
|
||||
}
|
||||
}
|
||||
if _, err := reader.Next(); !errors.Is(err, io.EOF) {
|
||||
t.Fatalf("expected EOF after draining frames, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFrameReaderHandlesPartialTransportReads(t *testing.T) {
|
||||
var wire bytes.Buffer
|
||||
writer := NewFrameWriter(&wire, nil)
|
||||
if err := writer.WriteFrame([]byte("hello")); err != nil {
|
||||
t.Fatalf("WriteFrame failed: %v", err)
|
||||
}
|
||||
|
||||
reader := NewFrameReader(&chunkedReader{data: wire.Bytes(), max: 3}, nil)
|
||||
reader.SetReadBufferSize(3)
|
||||
got, err := reader.Next()
|
||||
if err != nil {
|
||||
t.Fatalf("Next returned error: %v", err)
|
||||
}
|
||||
if string(got) != "hello" {
|
||||
t.Fatalf("unexpected frame: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFrameWriterNilWriterFails(t *testing.T) {
|
||||
writer := NewFrameWriter(nil, nil)
|
||||
if err := writer.WriteFrame([]byte("hello")); !errors.Is(err, io.ErrClosedPipe) {
|
||||
t.Fatalf("expected io.ErrClosedPipe, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFrameReaderTruncatedFrameReturnsUnexpectedEOF(t *testing.T) {
|
||||
queue := NewQueue()
|
||||
frame := queue.BuildMessage([]byte("hello"))
|
||||
reader := NewFrameReader(bytes.NewReader(frame[:len(frame)-1]), nil)
|
||||
|
||||
if _, err := reader.Next(); !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
t.Fatalf("expected io.ErrUnexpectedEOF, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFrameReaderReturnsUnexpectedEOFAfterPendingFrames(t *testing.T) {
|
||||
queue := NewQueue()
|
||||
first := queue.BuildMessage([]byte("one"))
|
||||
second := queue.BuildMessage([]byte("two"))
|
||||
wire := append(append([]byte{}, first...), second[:len(second)-1]...)
|
||||
reader := NewFrameReader(bytes.NewReader(wire), nil)
|
||||
|
||||
got, err := reader.Next()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error on first frame: %v", err)
|
||||
}
|
||||
if string(got) != "one" {
|
||||
t.Fatalf("unexpected first frame: %q", got)
|
||||
}
|
||||
if _, err := reader.Next(); !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
t.Fatalf("expected io.ErrUnexpectedEOF after draining pending frame, got %v", err)
|
||||
}
|
||||
}
|
||||
6
go.mod
6
go.mod
@ -1,5 +1,7 @@
|
||||
module b612.me/stario
|
||||
|
||||
go 1.16
|
||||
go 1.20
|
||||
|
||||
require golang.org/x/crypto v0.26.0
|
||||
require golang.org/x/term v0.23.0
|
||||
|
||||
require golang.org/x/sys v0.23.0 // indirect
|
||||
|
||||
63
go.sum
63
go.sum
@ -1,67 +1,4 @@
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw=
|
||||
golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM=
|
||||
golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
||||
golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
|
||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
||||
golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU=
|
||||
golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
||||
239
io.go
239
io.go
@ -3,11 +3,12 @@ package stario
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"golang.org/x/crypto/ssh/terminal"
|
||||
"golang.org/x/term"
|
||||
"io"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type InputMsg struct {
|
||||
@ -16,30 +17,54 @@ type InputMsg struct {
|
||||
skipSliceSigErr bool
|
||||
}
|
||||
|
||||
type rawInputSignalMode uint8
|
||||
|
||||
const (
|
||||
rawInputSignalIgnore rawInputSignalMode = iota
|
||||
rawInputSignalExit
|
||||
rawInputSignalReturnError
|
||||
)
|
||||
|
||||
type rawTerminalSession struct {
|
||||
fd int
|
||||
state *terminal.State
|
||||
state *term.State
|
||||
reader *bufio.Reader
|
||||
input io.Closer
|
||||
redrawHint string
|
||||
printNewline bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
var rawTerminalSessionFactory = newRawTerminalSession
|
||||
var inputSignalHandler = signal
|
||||
|
||||
// Passwd reads one password-style line in raw mode.
|
||||
//
|
||||
// When the user presses an input signal such as Ctrl+C, this compatibility
|
||||
// entry exits the current flow and returns an empty message with a nil error.
|
||||
func Passwd(hint string, defaultVal string) InputMsg {
|
||||
return passwd(hint, defaultVal, "", true)
|
||||
return passwd(hint, defaultVal, "", rawInputSignalExit)
|
||||
}
|
||||
|
||||
// PasswdWithMask is like Passwd but echoes the provided mask string.
|
||||
func PasswdWithMask(hint string, defaultVal string, mask string) InputMsg {
|
||||
return passwd(hint, defaultVal, mask, true)
|
||||
return passwd(hint, defaultVal, mask, rawInputSignalExit)
|
||||
}
|
||||
|
||||
// PasswdResponseSignal is like Passwd but preserves input-signal errors for
|
||||
// callers that need to distinguish Ctrl+C / Ctrl+Z style exits.
|
||||
func PasswdResponseSignal(hint string, defaultVal string) InputMsg {
|
||||
return passwd(hint, defaultVal, "", true)
|
||||
return passwd(hint, defaultVal, "", rawInputSignalReturnError)
|
||||
}
|
||||
|
||||
// PasswdResponseSignalWithMask is like PasswdResponseSignal but echoes the
|
||||
// provided mask string.
|
||||
func PasswdResponseSignalWithMask(hint string, defaultVal string, mask string) InputMsg {
|
||||
return passwd(hint, defaultVal, mask, true)
|
||||
return passwd(hint, defaultVal, mask, rawInputSignalReturnError)
|
||||
}
|
||||
|
||||
// MessageBoxRaw reads one line in raw mode without treating control keys as
|
||||
// exit signals.
|
||||
func MessageBoxRaw(hint string, defaultVal string) InputMsg {
|
||||
return messageBox(hint, defaultVal)
|
||||
}
|
||||
@ -48,35 +73,64 @@ func newRawTerminalSession(hint string, printNewline bool) (*rawTerminalSession,
|
||||
if hint != "" {
|
||||
fmt.Print(hint)
|
||||
}
|
||||
fd := int(os.Stdin.Fd())
|
||||
state, err := terminal.MakeRaw(fd)
|
||||
input, err := openRawTerminalInput()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fd := int(input.Fd())
|
||||
state, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
_ = input.Close()
|
||||
return nil, err
|
||||
}
|
||||
return &rawTerminalSession{
|
||||
fd: fd,
|
||||
state: state,
|
||||
reader: bufio.NewReader(os.Stdin),
|
||||
reader: bufio.NewReader(input),
|
||||
input: input,
|
||||
redrawHint: promptRedrawHint(hint),
|
||||
printNewline: printNewline,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (session *rawTerminalSession) Close() {
|
||||
if session == nil || session.state == nil {
|
||||
if session == nil {
|
||||
return
|
||||
}
|
||||
_ = terminal.Restore(session.fd, session.state)
|
||||
session.mu.Lock()
|
||||
defer session.mu.Unlock()
|
||||
if session.state != nil {
|
||||
_ = term.Restore(session.fd, session.state)
|
||||
session.state = nil
|
||||
}
|
||||
if session.printNewline {
|
||||
fmt.Println()
|
||||
session.printNewline = false
|
||||
}
|
||||
if session.input != nil {
|
||||
_ = session.input.Close()
|
||||
session.input = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (session *rawTerminalSession) Restore() error {
|
||||
if session == nil || session.state == nil {
|
||||
if session == nil {
|
||||
return nil
|
||||
}
|
||||
return terminal.Restore(session.fd, session.state)
|
||||
session.mu.Lock()
|
||||
defer session.mu.Unlock()
|
||||
if session.state == nil {
|
||||
return nil
|
||||
}
|
||||
if err := term.Restore(session.fd, session.state); err != nil {
|
||||
return err
|
||||
}
|
||||
session.state = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (session *rawTerminalSession) Abort() {
|
||||
session.Close()
|
||||
}
|
||||
|
||||
func promptRedrawHint(hint string) string {
|
||||
@ -101,6 +155,20 @@ func renderRawEcho(ioBuf []rune, mask string) string {
|
||||
return strings.Repeat(mask, len(ioBuf))
|
||||
}
|
||||
|
||||
func rawEchoRenderUnit(r rune, mask string, maskWidth int) (string, int, bool) {
|
||||
if mask != "" {
|
||||
if maskWidth <= 0 {
|
||||
return "", 0, false
|
||||
}
|
||||
return mask, maskWidth, true
|
||||
}
|
||||
width := runeDisplayWidth(r)
|
||||
if width <= 0 {
|
||||
return "", 0, false
|
||||
}
|
||||
return string(r), width, true
|
||||
}
|
||||
|
||||
func redrawPromptLine(hint string, echo string, lastWidth int) int {
|
||||
nowWidth := stringDisplayWidth(hint) + stringDisplayWidth(echo)
|
||||
clearWidth := lastWidth
|
||||
@ -121,6 +189,22 @@ func redrawPromptLine(hint string, echo string, lastWidth int) int {
|
||||
return nowWidth
|
||||
}
|
||||
|
||||
func erasePromptTail(width int) {
|
||||
if width <= 0 {
|
||||
return
|
||||
}
|
||||
backtrack := strings.Repeat("\b", width)
|
||||
fmt.Print(backtrack)
|
||||
fmt.Print(strings.Repeat(" ", width))
|
||||
fmt.Print(backtrack)
|
||||
}
|
||||
|
||||
func redrawPromptEcho(hint string, ioBuf []rune, mask string, lastWidth int) (int, int) {
|
||||
echo := renderRawEcho(ioBuf, mask)
|
||||
echoWidth := stringDisplayWidth(echo)
|
||||
return echoWidth, redrawPromptLine(hint, echo, lastWidth)
|
||||
}
|
||||
|
||||
func stringDisplayWidth(text string) int {
|
||||
width := 0
|
||||
for _, r := range text {
|
||||
@ -157,46 +241,79 @@ func isWideRune(r rune) bool {
|
||||
(r >= 0x20000 && r <= 0x3fffd))
|
||||
}
|
||||
|
||||
func rawLineInput(hint string, defaultVal string, mask string, handleSignal bool) InputMsg {
|
||||
session, err := newRawTerminalSession(hint, true)
|
||||
func signalInputResult(mode rawInputSignalMode, err error) InputMsg {
|
||||
switch mode {
|
||||
case rawInputSignalExit:
|
||||
return InputMsg{msg: "", err: nil}
|
||||
case rawInputSignalReturnError:
|
||||
return InputMsg{msg: "", err: err}
|
||||
default:
|
||||
return InputMsg{msg: "", err: nil}
|
||||
}
|
||||
}
|
||||
|
||||
func rawLineInput(hint string, defaultVal string, mask string, signalMode rawInputSignalMode) InputMsg {
|
||||
session, err := rawTerminalSessionFactory(hint, true)
|
||||
if err != nil {
|
||||
return InputMsg{msg: "", err: err}
|
||||
}
|
||||
defer session.Close()
|
||||
return rawLineInputSession(session, defaultVal, mask, signalMode)
|
||||
}
|
||||
|
||||
func rawLineInputSession(session *rawTerminalSession, defaultVal string, mask string, signalMode rawInputSignalMode) InputMsg {
|
||||
if session == nil || session.reader == nil {
|
||||
return InputMsg{msg: "", err: io.ErrClosedPipe}
|
||||
}
|
||||
ioBuf := make([]rune, 0, 16)
|
||||
lastWidth := 0
|
||||
promptWidth := stringDisplayWidth(session.redrawHint)
|
||||
maskWidth := stringDisplayWidth(mask)
|
||||
echoWidth := 0
|
||||
lastWidth := promptWidth
|
||||
for {
|
||||
b, _, err := session.reader.ReadRune()
|
||||
if err != nil {
|
||||
return InputMsg{msg: "", err: err}
|
||||
}
|
||||
if handleSignal && isSignal(b) {
|
||||
if runtime.GOOS != "windows" {
|
||||
if err := session.Restore(); err != nil {
|
||||
return InputMsg{msg: "", err: err}
|
||||
}
|
||||
if signalMode != rawInputSignalIgnore && isSignal(b) {
|
||||
session.Close()
|
||||
if signalMode == rawInputSignalExit {
|
||||
return signalInputResult(signalMode, nil)
|
||||
}
|
||||
if err := signal(b); err != nil {
|
||||
return InputMsg{msg: "", err: err}
|
||||
}
|
||||
continue
|
||||
return signalInputResult(signalMode, inputSignalHandler(b))
|
||||
}
|
||||
switch b {
|
||||
case 0x0d, 0x0a:
|
||||
return InputMsg{msg: finalizeInputValue(string(ioBuf), defaultVal), err: nil}
|
||||
case 0x08, 0x7F:
|
||||
if len(ioBuf) > 0 {
|
||||
removed := ioBuf[len(ioBuf)-1]
|
||||
ioBuf = ioBuf[:len(ioBuf)-1]
|
||||
if _, removedWidth, ok := rawEchoRenderUnit(removed, mask, maskWidth); ok {
|
||||
erasePromptTail(removedWidth)
|
||||
echoWidth -= removedWidth
|
||||
if echoWidth < 0 {
|
||||
echoWidth = 0
|
||||
}
|
||||
lastWidth = promptWidth + echoWidth
|
||||
continue
|
||||
}
|
||||
}
|
||||
default:
|
||||
ioBuf = append(ioBuf, b)
|
||||
if appendText, appendWidth, ok := rawEchoRenderUnit(b, mask, maskWidth); ok {
|
||||
fmt.Print(appendText)
|
||||
echoWidth += appendWidth
|
||||
lastWidth = promptWidth + echoWidth
|
||||
continue
|
||||
}
|
||||
}
|
||||
lastWidth = redrawPromptLine(session.redrawHint, renderRawEcho(ioBuf, mask), lastWidth)
|
||||
echoWidth, lastWidth = redrawPromptEcho(session.redrawHint, ioBuf, mask, lastWidth)
|
||||
}
|
||||
}
|
||||
|
||||
func messageBox(hint string, defaultVal string) InputMsg {
|
||||
return rawLineInput(hint, defaultVal, "", false)
|
||||
return rawLineInput(hint, defaultVal, "", rawInputSignalIgnore)
|
||||
}
|
||||
|
||||
func isSignal(s rune) bool {
|
||||
@ -208,10 +325,12 @@ func isSignal(s rune) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func passwd(hint string, defaultVal string, mask string, handleSignal bool) InputMsg {
|
||||
return rawLineInput(hint, defaultVal, mask, handleSignal)
|
||||
func passwd(hint string, defaultVal string, mask string, signalMode rawInputSignalMode) InputMsg {
|
||||
return rawLineInput(hint, defaultVal, mask, signalMode)
|
||||
}
|
||||
|
||||
// MessageBox reads one line in cooked mode and falls back to defaultVal when
|
||||
// the trimmed input is empty.
|
||||
func MessageBox(hint string, defaultVal string) InputMsg {
|
||||
if hint != "" {
|
||||
fmt.Print(hint)
|
||||
@ -264,7 +383,7 @@ func (im InputMsg) sliceFn(sep string, fn func(string) (interface{}, error)) ([]
|
||||
return res, err
|
||||
}
|
||||
for _, v := range data {
|
||||
code, err := fn(v)
|
||||
code, err := fn(strings.TrimSpace(v))
|
||||
if err != nil && !im.skipSliceSigErr {
|
||||
return nil, err
|
||||
} else if err == nil {
|
||||
@ -428,7 +547,8 @@ func (im InputMsg) MustFloat32() float32 {
|
||||
|
||||
func (im InputMsg) SliceFloat32(sep string) ([]float32, error) {
|
||||
data, err := im.sliceFn(sep, func(v string) (interface{}, error) {
|
||||
return strconv.ParseFloat(v, 32)
|
||||
f, err := strconv.ParseFloat(v, 32)
|
||||
return float32(f), err
|
||||
})
|
||||
var res []float32
|
||||
for _, v := range data {
|
||||
@ -477,13 +597,47 @@ func YesNoE(hint string, defaults bool) (bool, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func buildTriggerPrefixTable(triggerRunes []rune) []int {
|
||||
if len(triggerRunes) == 0 {
|
||||
return nil
|
||||
}
|
||||
prefix := make([]int, len(triggerRunes))
|
||||
for i := 1; i < len(triggerRunes); i++ {
|
||||
j := prefix[i-1]
|
||||
for j > 0 && triggerRunes[i] != triggerRunes[j] {
|
||||
j = prefix[j-1]
|
||||
}
|
||||
if triggerRunes[i] == triggerRunes[j] {
|
||||
j++
|
||||
}
|
||||
prefix[i] = j
|
||||
}
|
||||
return prefix
|
||||
}
|
||||
|
||||
func advanceTriggerIndex(triggerRunes []rune, prefix []int, current int, input rune) (int, bool) {
|
||||
if len(triggerRunes) == 0 {
|
||||
return 0, true
|
||||
}
|
||||
for current > 0 && input != triggerRunes[current] {
|
||||
current = prefix[current-1]
|
||||
}
|
||||
if input == triggerRunes[current] {
|
||||
current++
|
||||
if current == len(triggerRunes) {
|
||||
return current, true
|
||||
}
|
||||
}
|
||||
return current, false
|
||||
}
|
||||
|
||||
// StopUntil keeps reading raw input until trigger is matched.
|
||||
// When trigger == "", it returns after the first key press, which is used for
|
||||
// "press any key to continue" style prompts.
|
||||
func StopUntil(hint string, trigger string, repeat bool) error {
|
||||
triggerRunes := []rune(trigger)
|
||||
pressLen := len(triggerRunes)
|
||||
if trigger == "" {
|
||||
pressLen = 1
|
||||
}
|
||||
session, err := newRawTerminalSession(hint, false)
|
||||
prefix := buildTriggerPrefixTable(triggerRunes)
|
||||
session, err := rawTerminalSessionFactory(hint, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -497,11 +651,12 @@ func StopUntil(hint string, trigger string, repeat bool) error {
|
||||
if trigger == "" {
|
||||
break
|
||||
}
|
||||
if b == triggerRunes[i] {
|
||||
i++
|
||||
if i == pressLen {
|
||||
break
|
||||
}
|
||||
next, complete := advanceTriggerIndex(triggerRunes, prefix, i, b)
|
||||
if complete {
|
||||
break
|
||||
}
|
||||
if next > 0 {
|
||||
i = next
|
||||
continue
|
||||
}
|
||||
i = 0
|
||||
|
||||
81
io_benchmark_test.go
Normal file
81
io_benchmark_test.go
Normal file
@ -0,0 +1,81 @@
|
||||
package stario
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func benchmarkRawRedrawAppend(b *testing.B, current []rune, next rune, mask string) {
|
||||
b.Run("legacy", func(b *testing.B) {
|
||||
ioBuf := make([]rune, len(current)+1)
|
||||
copy(ioBuf, current)
|
||||
ioBuf[len(current)] = next
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
echo := renderRawEcho(ioBuf, mask)
|
||||
_ = stringDisplayWidth(echo)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("fast", func(b *testing.B) {
|
||||
maskWidth := stringDisplayWidth(mask)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, ok := rawEchoRenderUnit(next, mask, maskWidth)
|
||||
if !ok {
|
||||
b.Fatal("fast path rejected append rune")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func benchmarkRawRedrawBackspace(b *testing.B, current []rune, mask string) {
|
||||
if len(current) == 0 {
|
||||
b.Fatal("backspace benchmark requires non-empty input")
|
||||
}
|
||||
last := current[len(current)-1]
|
||||
trimmed := current[:len(current)-1]
|
||||
|
||||
b.Run("legacy", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
echo := renderRawEcho(trimmed, mask)
|
||||
_ = stringDisplayWidth(echo)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("fast", func(b *testing.B) {
|
||||
maskWidth := stringDisplayWidth(mask)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, ok := rawEchoRenderUnit(last, mask, maskWidth)
|
||||
if !ok {
|
||||
b.Fatal("fast path rejected backspace rune")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkRawRedrawAppendPlainLong(b *testing.B) {
|
||||
current := []rune(strings.Repeat("a", 4096))
|
||||
benchmarkRawRedrawAppend(b, current, 'b', "")
|
||||
}
|
||||
|
||||
func BenchmarkRawRedrawAppendMaskLong(b *testing.B) {
|
||||
current := []rune(strings.Repeat("a", 4096))
|
||||
benchmarkRawRedrawAppend(b, current, 'b', "[]")
|
||||
}
|
||||
|
||||
func BenchmarkRawRedrawBackspacePlainLong(b *testing.B) {
|
||||
current := []rune(strings.Repeat("a", 4096))
|
||||
benchmarkRawRedrawBackspace(b, current, "")
|
||||
}
|
||||
|
||||
func BenchmarkRawRedrawBackspaceMaskLong(b *testing.B) {
|
||||
current := []rune(strings.Repeat("a", 4096))
|
||||
benchmarkRawRedrawBackspace(b, current, "[]")
|
||||
}
|
||||
177
io_context.go
Normal file
177
io_context.go
Normal file
@ -0,0 +1,177 @@
|
||||
package stario
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
const defaultCopyContextBufferSize = 32 * 1024
|
||||
|
||||
type readContextResult struct {
|
||||
data []byte
|
||||
err error
|
||||
}
|
||||
|
||||
type writeContextResult struct {
|
||||
n int
|
||||
err error
|
||||
}
|
||||
|
||||
// ReadFullContext reads exactly len(buf) bytes unless the context is canceled
|
||||
// or the underlying reader returns an error.
|
||||
//
|
||||
// If ctx is canceled while the underlying Read call is blocked, ReadFullContext
|
||||
// returns ctx.Err() without waiting for that call to finish. The underlying
|
||||
// reader may still complete asynchronously afterwards.
|
||||
func ReadFullContext(ctx context.Context, reader io.Reader, buf []byte) (int, error) {
|
||||
if reader == nil {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
total := 0
|
||||
for total < len(buf) {
|
||||
chunk, err := readContext(ctx, reader, len(buf)-total)
|
||||
if len(chunk) > 0 {
|
||||
total += copy(buf[total:], chunk)
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
if total > 0 {
|
||||
return total, io.ErrUnexpectedEOF
|
||||
}
|
||||
return total, io.EOF
|
||||
}
|
||||
return total, err
|
||||
}
|
||||
if len(chunk) == 0 {
|
||||
return total, io.ErrNoProgress
|
||||
}
|
||||
}
|
||||
return total, nil
|
||||
}
|
||||
|
||||
// WriteFullContext writes the full payload unless the context is canceled or
|
||||
// the underlying writer returns an error.
|
||||
//
|
||||
// If ctx is canceled while the underlying Write call is blocked,
|
||||
// WriteFullContext returns ctx.Err() without waiting for that call to finish.
|
||||
// The underlying writer may still complete asynchronously afterwards.
|
||||
func WriteFullContext(ctx context.Context, writer io.Writer, data []byte) (int, error) {
|
||||
if writer == nil {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
total := 0
|
||||
for total < len(data) {
|
||||
written, err := writeContext(ctx, writer, data[total:])
|
||||
if written > 0 {
|
||||
total += written
|
||||
}
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
if written == 0 {
|
||||
return total, io.ErrNoProgress
|
||||
}
|
||||
}
|
||||
return total, nil
|
||||
}
|
||||
|
||||
// CopyContext copies from src to dst until EOF, context cancellation, or a
|
||||
// non-EOF error occurs.
|
||||
//
|
||||
// If ctx is canceled while the current read or write is blocked, CopyContext
|
||||
// returns ctx.Err() without waiting for that operation to finish.
|
||||
func CopyContext(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) {
|
||||
if dst == nil || src == nil {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
var copied int64
|
||||
for {
|
||||
chunk, readErr := readContext(ctx, src, defaultCopyContextBufferSize)
|
||||
if len(chunk) > 0 {
|
||||
written, writeErr := WriteFullContext(ctx, dst, chunk)
|
||||
copied += int64(written)
|
||||
if writeErr != nil {
|
||||
return copied, writeErr
|
||||
}
|
||||
}
|
||||
if readErr != nil {
|
||||
if errors.Is(readErr, io.EOF) {
|
||||
return copied, nil
|
||||
}
|
||||
return copied, readErr
|
||||
}
|
||||
if len(chunk) == 0 {
|
||||
return copied, io.ErrNoProgress
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func readContext(ctx context.Context, reader io.Reader, size int) ([]byte, error) {
|
||||
if size <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
resultCh := make(chan readContextResult, 1)
|
||||
go func() {
|
||||
tmp := make([]byte, size)
|
||||
n, err := reader.Read(tmp)
|
||||
if n > 0 {
|
||||
tmp = tmp[:n]
|
||||
} else {
|
||||
tmp = nil
|
||||
}
|
||||
resultCh <- readContextResult{data: tmp, err: err}
|
||||
}()
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
return result.data, result.err
|
||||
case <-ctx.Done():
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
return result.data, result.err
|
||||
default:
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func writeContext(ctx context.Context, writer io.Writer, data []byte) (int, error) {
|
||||
if len(data) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
payload := append([]byte(nil), data...)
|
||||
resultCh := make(chan writeContextResult, 1)
|
||||
go func() {
|
||||
n, err := writer.Write(payload)
|
||||
resultCh <- writeContextResult{n: n, err: err}
|
||||
}()
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
return result.n, result.err
|
||||
case <-ctx.Done():
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
return result.n, result.err
|
||||
default:
|
||||
return 0, ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
153
io_context_test.go
Normal file
153
io_context_test.go
Normal file
@ -0,0 +1,153 @@
|
||||
package stario
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type chunkedWriter struct {
|
||||
buf bytes.Buffer
|
||||
max int
|
||||
}
|
||||
|
||||
func (writer *chunkedWriter) Write(p []byte) (int, error) {
|
||||
if writer.max <= 0 || len(p) <= writer.max {
|
||||
return writer.buf.Write(p)
|
||||
}
|
||||
return writer.buf.Write(p[:writer.max])
|
||||
}
|
||||
|
||||
type blockingReader struct {
|
||||
started chan struct{}
|
||||
release chan struct{}
|
||||
data []byte
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (reader *blockingReader) Read(p []byte) (int, error) {
|
||||
reader.once.Do(func() { close(reader.started) })
|
||||
<-reader.release
|
||||
n := copy(p, reader.data)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
type blockingWriter struct {
|
||||
started chan struct{}
|
||||
release chan struct{}
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (writer *blockingWriter) Write(p []byte) (int, error) {
|
||||
writer.once.Do(func() { close(writer.started) })
|
||||
<-writer.release
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func TestReadFullContext(t *testing.T) {
|
||||
buf := make([]byte, 5)
|
||||
n, err := ReadFullContext(context.Background(), bytes.NewBufferString("hello"), buf)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFullContext returned error: %v", err)
|
||||
}
|
||||
if n != 5 || string(buf) != "hello" {
|
||||
t.Fatalf("unexpected payload: n=%d data=%q", n, buf)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFullContextReturnsUnexpectedEOF(t *testing.T) {
|
||||
buf := make([]byte, 5)
|
||||
n, err := ReadFullContext(context.Background(), bytes.NewBufferString("hey"), buf)
|
||||
if !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
t.Fatalf("expected unexpected EOF, got %v", err)
|
||||
}
|
||||
if n != 3 || string(buf[:n]) != "hey" {
|
||||
t.Fatalf("unexpected payload: n=%d data=%q", n, buf[:n])
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFullContextCanceledWhileBlocked(t *testing.T) {
|
||||
reader := &blockingReader{
|
||||
started: make(chan struct{}),
|
||||
release: make(chan struct{}),
|
||||
data: []byte("hello"),
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
buf := make([]byte, 5)
|
||||
_, err := ReadFullContext(ctx, reader, buf)
|
||||
done <- err
|
||||
}()
|
||||
|
||||
<-reader.started
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("expected context canceled, got %v", err)
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("ReadFullContext did not return after cancel")
|
||||
}
|
||||
|
||||
close(reader.release)
|
||||
}
|
||||
|
||||
func TestWriteFullContext(t *testing.T) {
|
||||
writer := &chunkedWriter{max: 2}
|
||||
n, err := WriteFullContext(context.Background(), writer, []byte("hello"))
|
||||
if err != nil {
|
||||
t.Fatalf("WriteFullContext returned error: %v", err)
|
||||
}
|
||||
if n != 5 || writer.buf.String() != "hello" {
|
||||
t.Fatalf("unexpected write result: n=%d data=%q", n, writer.buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteFullContextCanceledWhileBlocked(t *testing.T) {
|
||||
writer := &blockingWriter{
|
||||
started: make(chan struct{}),
|
||||
release: make(chan struct{}),
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := WriteFullContext(ctx, writer, []byte("hello"))
|
||||
done <- err
|
||||
}()
|
||||
|
||||
<-writer.started
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("expected context canceled, got %v", err)
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("WriteFullContext did not return after cancel")
|
||||
}
|
||||
|
||||
close(writer.release)
|
||||
}
|
||||
|
||||
func TestCopyContext(t *testing.T) {
|
||||
var dst bytes.Buffer
|
||||
written, err := CopyContext(context.Background(), &dst, bytes.NewBufferString("hello world"))
|
||||
if err != nil {
|
||||
t.Fatalf("CopyContext returned error: %v", err)
|
||||
}
|
||||
if written != int64(len("hello world")) || dst.String() != "hello world" {
|
||||
t.Fatalf("unexpected copy result: written=%d data=%q", written, dst.String())
|
||||
}
|
||||
}
|
||||
150
io_test.go
150
io_test.go
@ -1,10 +1,32 @@
|
||||
package stario
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func installRawInputStub(t *testing.T, input string, signalErr error) {
|
||||
t.Helper()
|
||||
prevFactory := rawTerminalSessionFactory
|
||||
prevSignalHandler := inputSignalHandler
|
||||
rawTerminalSessionFactory = func(hint string, printNewline bool) (*rawTerminalSession, error) {
|
||||
return &rawTerminalSession{
|
||||
reader: bufio.NewReader(strings.NewReader(input)),
|
||||
redrawHint: promptRedrawHint(hint),
|
||||
}, nil
|
||||
}
|
||||
inputSignalHandler = func(sigtype rune) error {
|
||||
return signalErr
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
rawTerminalSessionFactory = prevFactory
|
||||
inputSignalHandler = prevSignalHandler
|
||||
})
|
||||
}
|
||||
|
||||
func TestPromptRedrawHint(t *testing.T) {
|
||||
got := promptRedrawHint("头部提示\n 中文确认: ")
|
||||
if got != "中文确认:" {
|
||||
@ -19,6 +41,81 @@ func TestStringDisplayWidth(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRawEchoRenderUnitPlainRune(t *testing.T) {
|
||||
text, width, ok := rawEchoRenderUnit('中', "", 0)
|
||||
if !ok {
|
||||
t.Fatal("expected plain wide rune to use fast path")
|
||||
}
|
||||
if text != "中" || width != 2 {
|
||||
t.Fatalf("unexpected render unit: got (%q, %d) want (%q, %d)", text, width, "中", 2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRawEchoRenderUnitMaskedRune(t *testing.T) {
|
||||
text, width, ok := rawEchoRenderUnit('a', "[]", stringDisplayWidth("[]"))
|
||||
if !ok {
|
||||
t.Fatal("expected masked rune to use fast path")
|
||||
}
|
||||
if text != "[]" || width != 2 {
|
||||
t.Fatalf("unexpected render unit: got (%q, %d) want (%q, %d)", text, width, "[]", 2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRawEchoRenderUnitFallsBackForControlRune(t *testing.T) {
|
||||
if _, _, ok := rawEchoRenderUnit('\x00', "", 0); ok {
|
||||
t.Fatal("expected control rune to fall back to full redraw")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignalInputResultExitSuppressesError(t *testing.T) {
|
||||
got := signalInputResult(rawInputSignalExit, ErrSignalInterrupt)
|
||||
if got.err != nil {
|
||||
t.Fatalf("expected nil error, got %v", got.err)
|
||||
}
|
||||
if got.msg != "" {
|
||||
t.Fatalf("expected empty message, got %q", got.msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignalInputResultReturnErrorPreservesSignal(t *testing.T) {
|
||||
got := signalInputResult(rawInputSignalReturnError, ErrSignalInterrupt)
|
||||
if !errors.Is(got.err, ErrSignalInterrupt) {
|
||||
t.Fatalf("expected signal error, got %v", got.err)
|
||||
}
|
||||
if got.msg != "" {
|
||||
t.Fatalf("expected empty message, got %q", got.msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswdSuppressesSignalError(t *testing.T) {
|
||||
installRawInputStub(t, string([]rune{0x03}), ErrSignalInterrupt)
|
||||
got := Passwd("", "fallback")
|
||||
if got.err != nil {
|
||||
t.Fatalf("expected nil error, got %v", got.err)
|
||||
}
|
||||
if got.msg != "" {
|
||||
t.Fatalf("expected empty message after signal exit, got %q", got.msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswdResponseSignalPreservesSignalError(t *testing.T) {
|
||||
installRawInputStub(t, string([]rune{0x03}), ErrSignalInterrupt)
|
||||
got := PasswdResponseSignal("", "fallback")
|
||||
if !errors.Is(got.err, ErrSignalInterrupt) {
|
||||
t.Fatalf("expected interrupt error, got %v", got.err)
|
||||
}
|
||||
if got.msg != "" {
|
||||
t.Fatalf("expected empty message after signal exit, got %q", got.msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStopUntilEmptyTriggerReturnsAfterFirstKey(t *testing.T) {
|
||||
installRawInputStub(t, "abc", nil)
|
||||
if err := StopUntil("", "", false); err != nil {
|
||||
t.Fatalf("StopUntil returned error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseYesNoValue(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
@ -40,6 +137,59 @@ func TestParseYesNoValue(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSliceFloat32(t *testing.T) {
|
||||
data := InputMsg{msg: "1.5,2.25", err: nil}
|
||||
got, err := data.SliceFloat32(",")
|
||||
if err != nil {
|
||||
t.Fatalf("SliceFloat32 returned error: %v", err)
|
||||
}
|
||||
if len(got) != 2 || got[0] != float32(1.5) || got[1] != float32(2.25) {
|
||||
t.Fatalf("unexpected float32 slice: %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypedSliceParsingTrimsTokenWhitespace(t *testing.T) {
|
||||
ints, err := (InputMsg{msg: "1, 2, 3"}).SliceInt(",")
|
||||
if err != nil {
|
||||
t.Fatalf("SliceInt returned error: %v", err)
|
||||
}
|
||||
if len(ints) != 3 || ints[0] != 1 || ints[1] != 2 || ints[2] != 3 {
|
||||
t.Fatalf("unexpected int slice: %#v", ints)
|
||||
}
|
||||
|
||||
bools, err := (InputMsg{msg: "true, false, true"}).SliceBool(",")
|
||||
if err != nil {
|
||||
t.Fatalf("SliceBool returned error: %v", err)
|
||||
}
|
||||
if len(bools) != 3 || !bools[0] || bools[1] || !bools[2] {
|
||||
t.Fatalf("unexpected bool slice: %#v", bools)
|
||||
}
|
||||
|
||||
float64s, err := (InputMsg{msg: "1.25, 2.5, 3.75"}).SliceFloat64(",")
|
||||
if err != nil {
|
||||
t.Fatalf("SliceFloat64 returned error: %v", err)
|
||||
}
|
||||
if len(float64s) != 3 || float64s[0] != 1.25 || float64s[1] != 2.5 || float64s[2] != 3.75 {
|
||||
t.Fatalf("unexpected float64 slice: %#v", float64s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdvanceTriggerIndexHandlesOverlap(t *testing.T) {
|
||||
trigger := []rune("aba")
|
||||
prefix := buildTriggerPrefixTable(trigger)
|
||||
index := 0
|
||||
complete := false
|
||||
for _, r := range []rune("aaba") {
|
||||
index, complete = advanceTriggerIndex(trigger, prefix, index, r)
|
||||
if complete {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !complete {
|
||||
t.Fatal("expected overlapped trigger to complete")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Slice(t *testing.T) {
|
||||
var data = InputMsg{
|
||||
msg: "true,false,true,true,false,0,1,hello",
|
||||
|
||||
71
pipe.go
Normal file
71
pipe.go
Normal file
@ -0,0 +1,71 @@
|
||||
package stario
|
||||
|
||||
import "io"
|
||||
|
||||
// StarPipeReader is the read side returned by NewStarPipe.
|
||||
type StarPipeReader struct {
|
||||
buf *StarBuffer
|
||||
}
|
||||
|
||||
// StarPipeWriter is the write side returned by NewStarPipe.
|
||||
type StarPipeWriter struct {
|
||||
buf *StarBuffer
|
||||
}
|
||||
|
||||
// NewStarPipe creates a buffered in-memory pipe backed by StarBuffer.
|
||||
//
|
||||
// The writer side uses Close to signal a graceful end-of-stream and Abort to
|
||||
// fail both sides immediately.
|
||||
func NewStarPipe(capacity uint64) (*StarPipeReader, *StarPipeWriter, error) {
|
||||
buf, err := NewStarBuffer(capacity)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return &StarPipeReader{buf: buf}, &StarPipeWriter{buf: buf}, nil
|
||||
}
|
||||
|
||||
// Read reads buffered bytes from the pipe.
|
||||
func (reader *StarPipeReader) Read(p []byte) (int, error) {
|
||||
if reader == nil || reader.buf == nil {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
return reader.buf.Read(p)
|
||||
}
|
||||
|
||||
// Close aborts the pipe from the read side and wakes blocked writers.
|
||||
func (reader *StarPipeReader) Close() error {
|
||||
if reader == nil || reader.buf == nil {
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
return reader.buf.Abort()
|
||||
}
|
||||
|
||||
// Abort aborts the pipe from the read side and wakes blocked writers.
|
||||
func (reader *StarPipeReader) Abort() error {
|
||||
return reader.Close()
|
||||
}
|
||||
|
||||
// Write writes bytes into the pipe buffer.
|
||||
func (writer *StarPipeWriter) Write(p []byte) (int, error) {
|
||||
if writer == nil || writer.buf == nil {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
return writer.buf.Write(p)
|
||||
}
|
||||
|
||||
// Close gracefully closes the write side. Buffered bytes remain readable until
|
||||
// drained.
|
||||
func (writer *StarPipeWriter) Close() error {
|
||||
if writer == nil || writer.buf == nil {
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
return writer.buf.Close()
|
||||
}
|
||||
|
||||
// Abort aborts the pipe immediately.
|
||||
func (writer *StarPipeWriter) Abort() error {
|
||||
if writer == nil || writer.buf == nil {
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
return writer.buf.Abort()
|
||||
}
|
||||
53
pipe_test.go
Normal file
53
pipe_test.go
Normal file
@ -0,0 +1,53 @@
|
||||
package stario
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStarPipeRoundTrip(t *testing.T) {
|
||||
reader, writer, err := NewStarPipe(16)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
want := []byte("hello world")
|
||||
go func() {
|
||||
_, _ = writer.Write(want)
|
||||
_ = writer.Close()
|
||||
}()
|
||||
|
||||
got, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(got, want) {
|
||||
t.Fatalf("unexpected payload: got %q want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStarPipeReaderCloseAbortsWriter(t *testing.T) {
|
||||
reader, writer, err := NewStarPipe(4)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reader.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := writer.Write([]byte("x")); !errors.Is(err, ErrStarBufferClosed) {
|
||||
t.Fatalf("unexpected writer error after reader close: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStarPipeNilEndsReportClosedPipe(t *testing.T) {
|
||||
var reader *StarPipeReader
|
||||
var writer *StarPipeWriter
|
||||
|
||||
if _, err := reader.Read(make([]byte, 1)); !errors.Is(err, io.ErrClosedPipe) {
|
||||
t.Fatalf("unexpected nil reader error: %v", err)
|
||||
}
|
||||
if _, err := writer.Write([]byte("x")); !errors.Is(err, io.ErrClosedPipe) {
|
||||
t.Fatalf("unexpected nil writer error: %v", err)
|
||||
}
|
||||
}
|
||||
325
que.go
325
que.go
@ -1,325 +0,0 @@
|
||||
package stario
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ErrDeadlineExceeded error = errors.New("deadline exceeded")
|
||||
|
||||
// 识别头
|
||||
var header = []byte{11, 27, 19, 96, 12, 25, 02, 20}
|
||||
|
||||
// MsgQueue 为基本的信息单位
|
||||
type MsgQueue struct {
|
||||
ID uint16
|
||||
Msg []byte
|
||||
Conn interface{}
|
||||
}
|
||||
|
||||
// StarQueue 为流数据中的消息队列分发
|
||||
type StarQueue struct {
|
||||
maxLength uint32
|
||||
count int64
|
||||
Encode bool
|
||||
msgID uint16
|
||||
msgPool chan MsgQueue
|
||||
unFinMsg sync.Map
|
||||
lastID int //= -1
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
duration time.Duration
|
||||
EncodeFunc func([]byte) []byte
|
||||
DecodeFunc func([]byte) []byte
|
||||
//restoreMu sync.Mutex
|
||||
}
|
||||
|
||||
func NewQueueCtx(ctx context.Context, count int64, maxMsgLength uint32) *StarQueue {
|
||||
var q StarQueue
|
||||
q.Encode = false
|
||||
q.count = count
|
||||
q.maxLength = maxMsgLength
|
||||
q.msgPool = make(chan MsgQueue, count)
|
||||
if ctx == nil {
|
||||
q.ctx, q.cancel = context.WithCancel(context.Background())
|
||||
} else {
|
||||
q.ctx, q.cancel = context.WithCancel(ctx)
|
||||
}
|
||||
q.duration = 0
|
||||
return &q
|
||||
}
|
||||
func NewQueueWithCount(count int64) *StarQueue {
|
||||
return NewQueueCtx(nil, count, 0)
|
||||
}
|
||||
|
||||
// NewQueue 建立一个新消息队列
|
||||
func NewQueue() *StarQueue {
|
||||
return NewQueueWithCount(32)
|
||||
}
|
||||
|
||||
// Uint32ToByte 4位uint32转byte
|
||||
func Uint32ToByte(src uint32) []byte {
|
||||
res := make([]byte, 4)
|
||||
res[3] = uint8(src)
|
||||
res[2] = uint8(src >> 8)
|
||||
res[1] = uint8(src >> 16)
|
||||
res[0] = uint8(src >> 24)
|
||||
return res
|
||||
}
|
||||
|
||||
// ByteToUint32 byte转4位uint32
|
||||
func ByteToUint32(src []byte) uint32 {
|
||||
var res uint32
|
||||
buffer := bytes.NewBuffer(src)
|
||||
binary.Read(buffer, binary.BigEndian, &res)
|
||||
return res
|
||||
}
|
||||
|
||||
// Uint16ToByte 2位uint16转byte
|
||||
func Uint16ToByte(src uint16) []byte {
|
||||
res := make([]byte, 2)
|
||||
res[1] = uint8(src)
|
||||
res[0] = uint8(src >> 8)
|
||||
return res
|
||||
}
|
||||
|
||||
// ByteToUint16 用于byte转uint16
|
||||
func ByteToUint16(src []byte) uint16 {
|
||||
var res uint16
|
||||
buffer := bytes.NewBuffer(src)
|
||||
binary.Read(buffer, binary.BigEndian, &res)
|
||||
return res
|
||||
}
|
||||
|
||||
// BuildMessage 生成编码后的信息用于发送
|
||||
func (q *StarQueue) BuildMessage(src []byte) []byte {
|
||||
var buff bytes.Buffer
|
||||
q.msgID++
|
||||
if q.Encode {
|
||||
src = q.EncodeFunc(src)
|
||||
}
|
||||
length := uint32(len(src))
|
||||
buff.Write(header)
|
||||
buff.Write(Uint32ToByte(length))
|
||||
buff.Write(Uint16ToByte(q.msgID))
|
||||
buff.Write(src)
|
||||
return buff.Bytes()
|
||||
}
|
||||
|
||||
// BuildHeader 生成编码后的Header用于发送
|
||||
func (q *StarQueue) BuildHeader(length uint32) []byte {
|
||||
var buff bytes.Buffer
|
||||
q.msgID++
|
||||
buff.Write(header)
|
||||
buff.Write(Uint32ToByte(length))
|
||||
buff.Write(Uint16ToByte(q.msgID))
|
||||
return buff.Bytes()
|
||||
}
|
||||
|
||||
type unFinMsg struct {
|
||||
ID uint16
|
||||
LengthRecv uint32
|
||||
// HeaderMsg 信息头,应当为14位:8位识别码+4位长度码+2位id
|
||||
HeaderMsg []byte
|
||||
RecvMsg []byte
|
||||
}
|
||||
|
||||
func (q *StarQueue) push2list(msg MsgQueue) {
|
||||
q.msgPool <- msg
|
||||
}
|
||||
|
||||
// ParseMessage 用于解析收到的msg信息
|
||||
func (q *StarQueue) ParseMessage(msg []byte, conn interface{}) error {
|
||||
return q.parseMessage(msg, conn)
|
||||
}
|
||||
|
||||
// parseMessage 用于解析收到的msg信息
|
||||
func (q *StarQueue) parseMessage(msg []byte, conn interface{}) error {
|
||||
tmp, ok := q.unFinMsg.Load(conn)
|
||||
if ok { //存在未完成的信息
|
||||
lastMsg := tmp.(*unFinMsg)
|
||||
headerLen := len(lastMsg.HeaderMsg)
|
||||
if headerLen < 14 { //未完成头标题
|
||||
//传输的数据不能填充header头
|
||||
if len(msg) < 14-headerLen {
|
||||
//加入header头并退出
|
||||
lastMsg.HeaderMsg = bytesMerge(lastMsg.HeaderMsg, msg)
|
||||
q.unFinMsg.Store(conn, lastMsg)
|
||||
return nil
|
||||
}
|
||||
//获取14字节完整的header
|
||||
header := msg[0 : 14-headerLen]
|
||||
lastMsg.HeaderMsg = bytesMerge(lastMsg.HeaderMsg, header)
|
||||
//检查收到的header是否为认证header
|
||||
//若不是,丢弃并重新来过
|
||||
if !checkHeader(lastMsg.HeaderMsg[0:8]) {
|
||||
q.unFinMsg.Delete(conn)
|
||||
if len(msg) == 0 {
|
||||
return nil
|
||||
}
|
||||
return q.parseMessage(msg, conn)
|
||||
}
|
||||
//获得本数据包长度
|
||||
lastMsg.LengthRecv = ByteToUint32(lastMsg.HeaderMsg[8:12])
|
||||
if q.maxLength != 0 && lastMsg.LengthRecv > q.maxLength {
|
||||
q.unFinMsg.Delete(conn)
|
||||
return fmt.Errorf("msg length is %d ,too large than %d", lastMsg.LengthRecv, q.maxLength)
|
||||
}
|
||||
//获得本数据包ID
|
||||
lastMsg.ID = ByteToUint16(lastMsg.HeaderMsg[12:14])
|
||||
//存入列表
|
||||
q.unFinMsg.Store(conn, lastMsg)
|
||||
msg = msg[14-headerLen:]
|
||||
if uint32(len(msg)) < lastMsg.LengthRecv {
|
||||
lastMsg.RecvMsg = msg
|
||||
q.unFinMsg.Store(conn, lastMsg)
|
||||
return nil
|
||||
}
|
||||
if uint32(len(msg)) >= lastMsg.LengthRecv {
|
||||
lastMsg.RecvMsg = msg[0:lastMsg.LengthRecv]
|
||||
if q.Encode {
|
||||
lastMsg.RecvMsg = q.DecodeFunc(lastMsg.RecvMsg)
|
||||
}
|
||||
msg = msg[lastMsg.LengthRecv:]
|
||||
storeMsg := MsgQueue{
|
||||
ID: lastMsg.ID,
|
||||
Msg: lastMsg.RecvMsg,
|
||||
Conn: conn,
|
||||
}
|
||||
//q.restoreMu.Lock()
|
||||
q.push2list(storeMsg)
|
||||
//q.restoreMu.Unlock()
|
||||
q.unFinMsg.Delete(conn)
|
||||
return q.parseMessage(msg, conn)
|
||||
}
|
||||
} else {
|
||||
lastID := int(lastMsg.LengthRecv) - len(lastMsg.RecvMsg)
|
||||
if lastID < 0 {
|
||||
q.unFinMsg.Delete(conn)
|
||||
return q.parseMessage(msg, conn)
|
||||
}
|
||||
if len(msg) >= lastID {
|
||||
lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg[0:lastID])
|
||||
if q.Encode {
|
||||
lastMsg.RecvMsg = q.DecodeFunc(lastMsg.RecvMsg)
|
||||
}
|
||||
storeMsg := MsgQueue{
|
||||
ID: lastMsg.ID,
|
||||
Msg: lastMsg.RecvMsg,
|
||||
Conn: conn,
|
||||
}
|
||||
q.push2list(storeMsg)
|
||||
q.unFinMsg.Delete(conn)
|
||||
if len(msg) == lastID {
|
||||
return nil
|
||||
}
|
||||
msg = msg[lastID:]
|
||||
return q.parseMessage(msg, conn)
|
||||
}
|
||||
lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg)
|
||||
q.unFinMsg.Store(conn, lastMsg)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if len(msg) == 0 {
|
||||
return nil
|
||||
}
|
||||
var start int
|
||||
if start = searchHeader(msg); start == -1 {
|
||||
return errors.New("data format error")
|
||||
}
|
||||
msg = msg[start:]
|
||||
lastMsg := unFinMsg{}
|
||||
q.unFinMsg.Store(conn, &lastMsg)
|
||||
return q.parseMessage(msg, conn)
|
||||
}
|
||||
|
||||
func checkHeader(msg []byte) bool {
|
||||
if len(msg) != 8 {
|
||||
return false
|
||||
}
|
||||
for k, v := range msg {
|
||||
if v != header[k] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func searchHeader(msg []byte) int {
|
||||
if len(msg) < 8 {
|
||||
return 0
|
||||
}
|
||||
for k, v := range msg {
|
||||
find := 0
|
||||
if v == header[0] {
|
||||
for k2, v2 := range header {
|
||||
if msg[k+k2] == v2 {
|
||||
find++
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
if find == 8 {
|
||||
return k
|
||||
}
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func bytesMerge(src ...[]byte) []byte {
|
||||
var buff bytes.Buffer
|
||||
for _, v := range src {
|
||||
buff.Write(v)
|
||||
}
|
||||
return buff.Bytes()
|
||||
}
|
||||
|
||||
// Restore 获取收到的信息
|
||||
func (q *StarQueue) Restore() (MsgQueue, error) {
|
||||
if q.duration.Seconds() == 0 {
|
||||
q.duration = 86400 * time.Second
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-q.ctx.Done():
|
||||
return MsgQueue{}, errors.New("Stoped By External Function Call")
|
||||
case <-time.After(q.duration):
|
||||
if q.duration != 0 {
|
||||
return MsgQueue{}, ErrDeadlineExceeded
|
||||
}
|
||||
case data, ok := <-q.msgPool:
|
||||
if !ok {
|
||||
return MsgQueue{}, os.ErrClosed
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RestoreOne 获取收到的一个信息
|
||||
// 兼容性修改
|
||||
func (q *StarQueue) RestoreOne() (MsgQueue, error) {
|
||||
return q.Restore()
|
||||
}
|
||||
|
||||
// Stop 立即停止Restore
|
||||
func (q *StarQueue) Stop() {
|
||||
q.cancel()
|
||||
}
|
||||
|
||||
// RestoreDuration Restore最大超时时间
|
||||
func (q *StarQueue) RestoreDuration(tm time.Duration) {
|
||||
q.duration = tm
|
||||
}
|
||||
|
||||
func (q *StarQueue) RestoreChan() <-chan MsgQueue {
|
||||
return q.msgPool
|
||||
}
|
||||
80
que_benchmark_test.go
Normal file
80
que_benchmark_test.go
Normal file
@ -0,0 +1,80 @@
|
||||
package stario
|
||||
|
||||
import (
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func BenchmarkQueueBuildMessage64KiB(b *testing.B) {
|
||||
que := NewQueue()
|
||||
payload := make([]byte, 64*1024)
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(payload)))
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = que.BuildMessage(payload)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkQueueWriteFrame64KiB(b *testing.B) {
|
||||
que := NewQueue()
|
||||
payload := make([]byte, 64*1024)
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(payload)))
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := que.WriteFrame(io.Discard, payload); err != nil {
|
||||
b.Fatalf("WriteFrame failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkQueueWriteFrameBuffers64KiB(b *testing.B) {
|
||||
que := NewQueue()
|
||||
payload := make([]byte, 64*1024)
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(payload)))
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := que.WriteFrameBuffers(io.Discard, payload); err != nil {
|
||||
b.Fatalf("WriteFrameBuffers failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkQueueWriteFramesBuffers4x64KiB(b *testing.B) {
|
||||
que := NewQueue()
|
||||
payloads := [][]byte{
|
||||
make([]byte, 64*1024),
|
||||
make([]byte, 64*1024),
|
||||
make([]byte, 64*1024),
|
||||
make([]byte, 64*1024),
|
||||
}
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(payloads) * len(payloads[0])))
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := que.WriteFramesBuffers(io.Discard, payloads...); err != nil {
|
||||
b.Fatalf("WriteFramesBuffers failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkQueueParseRestoreHello(b *testing.B) {
|
||||
que := NewQueueWithCount(1)
|
||||
frame := que.BuildMessage([]byte("hello"))
|
||||
if len(frame) == 0 {
|
||||
b.Fatal("BuildMessage returned empty frame")
|
||||
}
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(frame)))
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := que.ParseMessage(frame, "bench"); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if _, err := que.Restore(); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
34
que_bytes.go
Normal file
34
que_bytes.go
Normal file
@ -0,0 +1,34 @@
|
||||
package stario
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
// Uint32ToByte 4位uint32转byte。
|
||||
func Uint32ToByte(src uint32) []byte {
|
||||
res := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(res, src)
|
||||
return res
|
||||
}
|
||||
|
||||
// ByteToUint32 byte转4位uint32。
|
||||
func ByteToUint32(src []byte) uint32 {
|
||||
return binary.BigEndian.Uint32(src)
|
||||
}
|
||||
|
||||
// Uint16ToByte 2位uint16转byte。
|
||||
func Uint16ToByte(src uint16) []byte {
|
||||
res := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(res, src)
|
||||
return res
|
||||
}
|
||||
|
||||
// ByteToUint16 用于byte转uint16。
|
||||
func ByteToUint16(src []byte) uint16 {
|
||||
return binary.BigEndian.Uint16(src)
|
||||
}
|
||||
|
||||
func cloneBytes(src []byte) []byte {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
return append([]byte(nil), src...)
|
||||
}
|
||||
208
que_frame.go
Normal file
208
que_frame.go
Normal file
@ -0,0 +1,208 @@
|
||||
package stario
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
// BuildMessage builds one full frame and panics if the payload is too large to
|
||||
// fit in the framing format.
|
||||
//
|
||||
// New code should prefer BuildMessageErr when it needs an explicit error path.
|
||||
func (q *StarQueue) BuildMessage(src []byte) []byte {
|
||||
frame, err := q.BuildMessageErr(src)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return frame
|
||||
}
|
||||
|
||||
// BuildMessageErr builds one full frame and returns an explicit error when the
|
||||
// payload exceeds the 32-bit frame length.
|
||||
func (q *StarQueue) BuildMessageErr(src []byte) ([]byte, error) {
|
||||
payload := src
|
||||
if q.Encode && q.EncodeFunc != nil {
|
||||
payload = q.EncodeFunc(payload)
|
||||
}
|
||||
length, err := payloadSizeToUint32(uint64(len(payload)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
header := q.BuildHeader(length)
|
||||
frame := make([]byte, 0, len(header)+len(payload))
|
||||
frame = append(frame, header...)
|
||||
frame = append(frame, payload...)
|
||||
return frame, nil
|
||||
}
|
||||
|
||||
// WriteFrame writes one framed payload directly to w without building an
|
||||
// intermediate full-frame slice first.
|
||||
func (q *StarQueue) WriteFrame(w io.Writer, src []byte) error {
|
||||
if w == nil {
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
payload := src
|
||||
if q.Encode && q.EncodeFunc != nil {
|
||||
payload = q.EncodeFunc(payload)
|
||||
}
|
||||
length, err := payloadSizeToUint32(uint64(len(payload)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var header [queHeaderSize]byte
|
||||
writeHeaderBytes(header[:], queHeader{
|
||||
Length: length,
|
||||
Version: queVersionV1,
|
||||
Flags: queSupportedFlags,
|
||||
})
|
||||
if err := writeFull(w, header[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
return writeFull(w, payload)
|
||||
}
|
||||
|
||||
// WriteFrameBuffers writes one framed payload using net.Buffers so callers can
|
||||
// opt into gather writes when the underlying writer supports it well.
|
||||
func (q *StarQueue) WriteFrameBuffers(w io.Writer, src []byte) error {
|
||||
if w == nil {
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
payload := src
|
||||
if q.Encode && q.EncodeFunc != nil {
|
||||
payload = q.EncodeFunc(payload)
|
||||
}
|
||||
length, err := payloadSizeToUint32(uint64(len(payload)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var header [queHeaderSize]byte
|
||||
writeHeaderBytes(header[:], queHeader{
|
||||
Length: length,
|
||||
Version: queVersionV1,
|
||||
Flags: queSupportedFlags,
|
||||
})
|
||||
buffers := net.Buffers{header[:], payload}
|
||||
_, err = buffers.WriteTo(w)
|
||||
return err
|
||||
}
|
||||
|
||||
// WriteFramesBuffers writes multiple framed payloads using one net.Buffers
|
||||
// batch so callers can reduce write calls on stream transports.
|
||||
func (q *StarQueue) WriteFramesBuffers(w io.Writer, payloads ...[]byte) error {
|
||||
if w == nil {
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
if len(payloads) == 0 {
|
||||
return nil
|
||||
}
|
||||
buffers := make(net.Buffers, 0, len(payloads)*2)
|
||||
headers := make([][queHeaderSize]byte, len(payloads))
|
||||
for i, src := range payloads {
|
||||
payload := src
|
||||
if q.Encode && q.EncodeFunc != nil {
|
||||
payload = q.EncodeFunc(payload)
|
||||
}
|
||||
length, err := payloadSizeToUint32(uint64(len(payload)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
writeHeaderBytes(headers[i][:], queHeader{
|
||||
Length: length,
|
||||
Version: queVersionV1,
|
||||
Flags: queSupportedFlags,
|
||||
})
|
||||
buffers = append(buffers, headers[i][:], payload)
|
||||
}
|
||||
_, err := buffers.WriteTo(w)
|
||||
return err
|
||||
}
|
||||
|
||||
// BuildHeader 生成编码后的Header用于发送。
|
||||
func (q *StarQueue) BuildHeader(length uint32) []byte {
|
||||
return buildHeaderBytes(queHeader{
|
||||
Length: length,
|
||||
Version: queVersionV1,
|
||||
Flags: queSupportedFlags,
|
||||
})
|
||||
}
|
||||
|
||||
func buildHeaderBytes(header queHeader) []byte {
|
||||
buf := make([]byte, queHeaderSize)
|
||||
writeHeaderBytes(buf, header)
|
||||
return buf
|
||||
}
|
||||
|
||||
func writeHeaderBytes(dst []byte, header queHeader) {
|
||||
if len(dst) < queHeaderSize {
|
||||
return
|
||||
}
|
||||
copy(dst[:queMagicSize], queMagic)
|
||||
binary.BigEndian.PutUint32(dst[queMagicSize:queMagicSize+4], header.Length)
|
||||
dst[12] = header.Version
|
||||
dst[13] = header.Flags
|
||||
}
|
||||
|
||||
func payloadSizeToUint32(size uint64) (uint32, error) {
|
||||
const maxFramePayload = ^uint32(0)
|
||||
if size > uint64(maxFramePayload) {
|
||||
return 0, fmt.Errorf("%w: %d > %d", ErrQueueMessageTooLarge, size, uint64(maxFramePayload))
|
||||
}
|
||||
return uint32(size), nil
|
||||
}
|
||||
|
||||
func writeFull(w io.Writer, data []byte) error {
|
||||
for len(data) > 0 {
|
||||
n, err := w.Write(data)
|
||||
if n > 0 {
|
||||
data = data[n:]
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n == 0 {
|
||||
return io.ErrNoProgress
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseHeaderBytes(src []byte, maxLength uint32) (queHeader, error) {
|
||||
if len(src) < queHeaderSize {
|
||||
return queHeader{}, ErrQueueDataFormat
|
||||
}
|
||||
if !equalMagic(src[:queMagicSize]) {
|
||||
return queHeader{}, ErrQueueDataFormat
|
||||
}
|
||||
|
||||
header := queHeader{
|
||||
Length: ByteToUint32(src[queMagicSize : queMagicSize+4]),
|
||||
Version: src[12],
|
||||
Flags: src[13],
|
||||
}
|
||||
|
||||
if header.Version != queVersionV1 {
|
||||
return queHeader{}, fmt.Errorf("%w: %d", ErrQueueUnsupportedVersion, header.Version)
|
||||
}
|
||||
if header.Flags != queSupportedFlags {
|
||||
return queHeader{}, fmt.Errorf("%w: %d", ErrQueueUnsupportedFlags, header.Flags)
|
||||
}
|
||||
if maxLength != 0 && header.Length > maxLength {
|
||||
return queHeader{}, fmt.Errorf("%w: %d > %d", ErrQueueMessageTooLarge, header.Length, maxLength)
|
||||
}
|
||||
|
||||
return header, nil
|
||||
}
|
||||
|
||||
func equalMagic(src []byte) bool {
|
||||
if len(src) != queMagicSize {
|
||||
return false
|
||||
}
|
||||
for i, b := range src {
|
||||
if b != queMagic[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
237
que_parse.go
Normal file
237
que_parse.go
Normal file
@ -0,0 +1,237 @@
|
||||
package stario
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// ParseMessage 用于解析收到的msg信息。
|
||||
func (q *StarQueue) ParseMessage(msg []byte, conn interface{}) error {
|
||||
return q.parseMessage(msg, conn, false, func(payload []byte) error {
|
||||
return q.push2list(MsgQueue{
|
||||
Msg: payload,
|
||||
Conn: conn,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// ParseMessageView parses frames and exposes each payload to fn without
|
||||
// forcing StarQueue to clone it first.
|
||||
//
|
||||
// The provided payload is only valid during the callback. If q uses the legacy
|
||||
// DecodeFunc path, StarQueue still has to allocate a decoded payload first.
|
||||
func (q *StarQueue) ParseMessageView(msg []byte, conn interface{}, fn func(FrameView) error) error {
|
||||
if fn == nil {
|
||||
return ErrQueueFrameHandlerNil
|
||||
}
|
||||
return q.parseMessage(msg, conn, true, func(payload []byte) error {
|
||||
return fn(FrameView{
|
||||
Payload: payload,
|
||||
Conn: conn,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// ParseMessageOwned parses frames and emits owned payload copies to fn without
|
||||
// routing them through RestoreChan.
|
||||
//
|
||||
// Compared with ParseMessage, this keeps StarQueue state handling the same but
|
||||
// lets callers decide how to dispatch parsed messages themselves.
|
||||
func (q *StarQueue) ParseMessageOwned(msg []byte, conn interface{}, fn func(MsgQueue) error) error {
|
||||
if fn == nil {
|
||||
return ErrQueueFrameHandlerNil
|
||||
}
|
||||
frames := make([]MsgQueue, 0, 1)
|
||||
parseErr := q.parseMessage(msg, conn, false, func(payload []byte) error {
|
||||
frames = append(frames, MsgQueue{
|
||||
Msg: payload,
|
||||
Conn: conn,
|
||||
})
|
||||
return nil
|
||||
})
|
||||
for _, frame := range frames {
|
||||
if err := fn(frame); err != nil {
|
||||
if parseErr != nil {
|
||||
return fmt.Errorf("%v: %w", parseErr, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
return parseErr
|
||||
}
|
||||
|
||||
func (q *StarQueue) parseMessage(msg []byte, conn interface{}, borrowPayload bool, emit func([]byte) error) error {
|
||||
state, err := q.connState(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var firstErr error
|
||||
for {
|
||||
payload, ok, err := q.nextPayload(state, conn, msg, borrowPayload)
|
||||
msg = nil
|
||||
if err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
if err := emit(payload); err != nil {
|
||||
if firstErr != nil {
|
||||
return fmt.Errorf("%v: %w", firstErr, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func (q *StarQueue) nextPayload(state *queConnState, conn interface{}, msg []byte, borrowPayload bool) ([]byte, bool, error) {
|
||||
state.mu.Lock()
|
||||
defer state.mu.Unlock()
|
||||
|
||||
if len(msg) != 0 {
|
||||
state.buf = append(state.buf, msg...)
|
||||
}
|
||||
|
||||
var firstErr error
|
||||
for {
|
||||
synced, err := syncFrameStart(&state.buf)
|
||||
if err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
if !synced {
|
||||
if len(state.buf) == 0 {
|
||||
q.states.Delete(conn)
|
||||
}
|
||||
return nil, false, firstErr
|
||||
}
|
||||
if len(state.buf) < queHeaderSize {
|
||||
return nil, false, firstErr
|
||||
}
|
||||
|
||||
header, err := parseHeaderBytes(state.buf[:queHeaderSize], q.maxLength)
|
||||
if err != nil {
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
state.buf = shrinkBuffer(state.buf[1:])
|
||||
continue
|
||||
}
|
||||
|
||||
frameLen := queHeaderSize + int(header.Length)
|
||||
if len(state.buf) < frameLen {
|
||||
return nil, false, firstErr
|
||||
}
|
||||
|
||||
payload, rest := extractPayload(state.buf, frameLen, borrowPayload && !(q.Encode && q.DecodeFunc != nil))
|
||||
state.buf = rest
|
||||
if q.Encode && q.DecodeFunc != nil {
|
||||
payload = q.DecodeFunc(payload)
|
||||
}
|
||||
if len(state.buf) == 0 {
|
||||
q.states.Delete(conn)
|
||||
}
|
||||
return payload, true, firstErr
|
||||
}
|
||||
}
|
||||
|
||||
func extractPayload(buf []byte, frameLen int, borrowPayload bool) ([]byte, []byte) {
|
||||
payload := buf[queHeaderSize:frameLen]
|
||||
if !borrowPayload {
|
||||
return cloneBytes(payload), shrinkBuffer(buf[frameLen:])
|
||||
}
|
||||
if frameLen == len(buf) {
|
||||
return payload, nil
|
||||
}
|
||||
return payload, cloneBytes(buf[frameLen:])
|
||||
}
|
||||
|
||||
func (q *StarQueue) push2list(msg MsgQueue) error {
|
||||
select {
|
||||
case <-q.ctx.Done():
|
||||
return q.ctx.Err()
|
||||
default:
|
||||
}
|
||||
q.sendMu.RLock()
|
||||
defer q.sendMu.RUnlock()
|
||||
select {
|
||||
case <-q.ctx.Done():
|
||||
return q.ctx.Err()
|
||||
case q.msgPool <- msg:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func validateConnKey(conn interface{}) error {
|
||||
if conn == nil {
|
||||
return ErrQueueConnKeyNil
|
||||
}
|
||||
typ := reflect.TypeOf(conn)
|
||||
if typ != nil && !typ.Comparable() {
|
||||
return ErrQueueConnKeyInvalid
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *StarQueue) connState(conn interface{}) (*queConnState, error) {
|
||||
if err := validateConnKey(conn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
state, _ := q.states.LoadOrStore(conn, &queConnState{})
|
||||
return state.(*queConnState), nil
|
||||
}
|
||||
|
||||
func syncFrameStart(buf *[]byte) (bool, error) {
|
||||
if len(*buf) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if len(*buf) >= queMagicSize && equalMagic((*buf)[:queMagicSize]) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
idx := bytes.Index(*buf, queMagic)
|
||||
if idx == 0 {
|
||||
return true, nil
|
||||
}
|
||||
if idx > 0 {
|
||||
*buf = cloneBytes((*buf)[idx:])
|
||||
return true, ErrQueueDataFormat
|
||||
}
|
||||
|
||||
keep := trailingMagicPrefixLen(*buf)
|
||||
if keep == len(*buf) {
|
||||
return false, nil
|
||||
}
|
||||
if keep > 0 {
|
||||
*buf = cloneBytes((*buf)[len(*buf)-keep:])
|
||||
return false, ErrQueueDataFormat
|
||||
}
|
||||
|
||||
*buf = (*buf)[:0]
|
||||
return false, ErrQueueDataFormat
|
||||
}
|
||||
|
||||
func trailingMagicPrefixLen(buf []byte) int {
|
||||
max := len(buf)
|
||||
if max > queMagicSize-1 {
|
||||
max = queMagicSize - 1
|
||||
}
|
||||
for keep := max; keep > 0; keep-- {
|
||||
if bytes.Equal(buf[len(buf)-keep:], queMagic[:keep]) {
|
||||
return keep
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func shrinkBuffer(buf []byte) []byte {
|
||||
if len(buf) == 0 {
|
||||
return nil
|
||||
}
|
||||
if cap(buf) > len(buf)*4 {
|
||||
return cloneBytes(buf)
|
||||
}
|
||||
return buf
|
||||
}
|
||||
80
que_runtime.go
Normal file
80
que_runtime.go
Normal file
@ -0,0 +1,80 @@
|
||||
package stario
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (q *StarQueue) closedRestoreErr() error {
|
||||
if err := q.ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.ErrClosed
|
||||
}
|
||||
|
||||
// Restore blocks until one message is available, the queue is stopped, or the
|
||||
// configured timeout expires.
|
||||
func (q *StarQueue) Restore() (MsgQueue, error) {
|
||||
if q.duration <= 0 {
|
||||
select {
|
||||
case <-q.ctx.Done():
|
||||
return MsgQueue{}, q.ctx.Err()
|
||||
case data, ok := <-q.msgPool:
|
||||
if !ok {
|
||||
return MsgQueue{}, q.closedRestoreErr()
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
}
|
||||
timer := time.NewTimer(q.duration)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-q.ctx.Done():
|
||||
return MsgQueue{}, q.ctx.Err()
|
||||
case <-timer.C:
|
||||
return MsgQueue{}, ErrDeadlineExceeded
|
||||
case data, ok := <-q.msgPool:
|
||||
if !ok {
|
||||
return MsgQueue{}, q.closedRestoreErr()
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
}
|
||||
|
||||
// RestoreOne 获取收到的一个信息。
|
||||
// 兼容性修改。
|
||||
func (q *StarQueue) RestoreOne() (MsgQueue, error) {
|
||||
return q.Restore()
|
||||
}
|
||||
|
||||
// Stop cancels the queue runtime.
|
||||
//
|
||||
// After Stop returns, Restore unblocks with context.Canceled and RestoreChan is
|
||||
// eventually closed.
|
||||
func (q *StarQueue) Stop() {
|
||||
q.cancel()
|
||||
q.shutdown()
|
||||
}
|
||||
|
||||
func (q *StarQueue) shutdown() {
|
||||
q.stopOnce.Do(func() {
|
||||
q.sendMu.Lock()
|
||||
close(q.msgPool)
|
||||
q.sendMu.Unlock()
|
||||
})
|
||||
}
|
||||
|
||||
// RestoreDuration sets the Restore timeout. A non-positive duration means wait
|
||||
// forever until a message arrives or Stop is called.
|
||||
func (q *StarQueue) RestoreDuration(tm time.Duration) {
|
||||
q.duration = tm
|
||||
}
|
||||
|
||||
// RestoreChan exposes the parsed message stream.
|
||||
//
|
||||
// The returned channel is closed after Stop or when the queue context is
|
||||
// canceled. New code should still prefer Restore when it needs timeout/error
|
||||
// classification instead of a plain stream.
|
||||
func (q *StarQueue) RestoreChan() <-chan MsgQueue {
|
||||
return q.msgPool
|
||||
}
|
||||
544
que_test.go
544
que_test.go
@ -1,42 +1,520 @@
|
||||
package stario
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_QueSpeed(t *testing.T) {
|
||||
que := NewQueueWithCount(0)
|
||||
stop := make(chan struct{}, 1)
|
||||
que.RestoreDuration(time.Second * 10)
|
||||
var count int64
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
//fmt.Println(count)
|
||||
return
|
||||
default:
|
||||
}
|
||||
_, err := que.RestoreOne()
|
||||
if err == nil {
|
||||
count++
|
||||
}
|
||||
}
|
||||
}()
|
||||
cp := 0
|
||||
stoped := time.After(time.Second * 10)
|
||||
data := que.BuildMessage([]byte("hello"))
|
||||
for {
|
||||
select {
|
||||
case <-stoped:
|
||||
fmt.Println(count, cp)
|
||||
stop <- struct{}{}
|
||||
return
|
||||
default:
|
||||
que.ParseMessage(data, "lala")
|
||||
cp++
|
||||
}
|
||||
func TestQueueBuildMessageUsesVersionedHeader(t *testing.T) {
|
||||
que := NewQueue()
|
||||
frame := que.BuildMessage([]byte("hello"))
|
||||
|
||||
if len(frame) != queHeaderSize+5 {
|
||||
t.Fatalf("unexpected frame length: got %d want %d", len(frame), queHeaderSize+5)
|
||||
}
|
||||
if !bytes.Equal(frame[:queMagicSize], queMagic) {
|
||||
t.Fatalf("unexpected magic: %v", frame[:queMagicSize])
|
||||
}
|
||||
if got := ByteToUint32(frame[queMagicSize : queMagicSize+4]); got != 5 {
|
||||
t.Fatalf("unexpected payload length: got %d want 5", got)
|
||||
}
|
||||
if frame[12] != queVersionV1 {
|
||||
t.Fatalf("unexpected version: got %d want %d", frame[12], queVersionV1)
|
||||
}
|
||||
if frame[13] != queSupportedFlags {
|
||||
t.Fatalf("unexpected flags: got %d want %d", frame[13], queSupportedFlags)
|
||||
}
|
||||
if !bytes.Equal(frame[queHeaderSize:], []byte("hello")) {
|
||||
t.Fatalf("unexpected payload: %q", frame[queHeaderSize:])
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueueWriteFrameMatchesBuildMessage(t *testing.T) {
|
||||
que := NewQueue()
|
||||
want := que.BuildMessage([]byte("hello"))
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := que.WriteFrame(&buf, []byte("hello")); err != nil {
|
||||
t.Fatalf("WriteFrame failed: %v", err)
|
||||
}
|
||||
if got := buf.Bytes(); !bytes.Equal(got, want) {
|
||||
t.Fatalf("WriteFrame mismatch: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueueWriteFrameBuffersMatchesBuildMessage(t *testing.T) {
|
||||
que := NewQueue()
|
||||
want := que.BuildMessage([]byte("hello"))
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := que.WriteFrameBuffers(&buf, []byte("hello")); err != nil {
|
||||
t.Fatalf("WriteFrameBuffers failed: %v", err)
|
||||
}
|
||||
if got := buf.Bytes(); !bytes.Equal(got, want) {
|
||||
t.Fatalf("WriteFrameBuffers mismatch: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueueWriteFramesBuffersMatchesBuildMessage(t *testing.T) {
|
||||
que := NewQueue()
|
||||
payloads := [][]byte{
|
||||
[]byte("hello"),
|
||||
[]byte("world"),
|
||||
[]byte("batch"),
|
||||
}
|
||||
|
||||
var want []byte
|
||||
for _, payload := range payloads {
|
||||
want = append(want, que.BuildMessage(payload)...)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := que.WriteFramesBuffers(&buf, payloads...); err != nil {
|
||||
t.Fatalf("WriteFramesBuffers failed: %v", err)
|
||||
}
|
||||
if got := buf.Bytes(); !bytes.Equal(got, want) {
|
||||
t.Fatalf("WriteFramesBuffers mismatch: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueueWriteFrameHonorsEncodeFunc(t *testing.T) {
|
||||
que := NewQueue()
|
||||
que.Encode = true
|
||||
que.EncodeFunc = bytes.ToUpper
|
||||
want := que.BuildMessage([]byte("hello"))
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := que.WriteFrame(&buf, []byte("hello")); err != nil {
|
||||
t.Fatalf("WriteFrame failed: %v", err)
|
||||
}
|
||||
if got := buf.Bytes(); !bytes.Equal(got, want) {
|
||||
t.Fatalf("WriteFrame mismatch with EncodeFunc: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueueWriteFrameBuffersHonorsEncodeFunc(t *testing.T) {
|
||||
que := NewQueue()
|
||||
que.Encode = true
|
||||
que.EncodeFunc = bytes.ToUpper
|
||||
want := que.BuildMessage([]byte("hello"))
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := que.WriteFrameBuffers(&buf, []byte("hello")); err != nil {
|
||||
t.Fatalf("WriteFrameBuffers failed: %v", err)
|
||||
}
|
||||
if got := buf.Bytes(); !bytes.Equal(got, want) {
|
||||
t.Fatalf("WriteFrameBuffers mismatch with EncodeFunc: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueueWriteFramesBuffersHonorsEncodeFunc(t *testing.T) {
|
||||
que := NewQueue()
|
||||
que.Encode = true
|
||||
que.EncodeFunc = bytes.ToUpper
|
||||
payloads := [][]byte{
|
||||
[]byte("hello"),
|
||||
[]byte("batch"),
|
||||
}
|
||||
|
||||
var want []byte
|
||||
for _, payload := range payloads {
|
||||
want = append(want, que.BuildMessage(payload)...)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := que.WriteFramesBuffers(&buf, payloads...); err != nil {
|
||||
t.Fatalf("WriteFramesBuffers failed: %v", err)
|
||||
}
|
||||
if got := buf.Bytes(); !bytes.Equal(got, want) {
|
||||
t.Fatalf("WriteFramesBuffers mismatch with EncodeFunc: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueueParseMessageSplitAcrossCalls(t *testing.T) {
|
||||
que := NewQueueWithCount(1)
|
||||
frame := que.BuildMessage([]byte("hello"))
|
||||
|
||||
if err := que.ParseMessage(frame[:3], "split"); err != nil {
|
||||
t.Fatalf("unexpected error on partial magic: %v", err)
|
||||
}
|
||||
if err := que.ParseMessage(frame[3:11], "split"); err != nil {
|
||||
t.Fatalf("unexpected error on partial header: %v", err)
|
||||
}
|
||||
if err := que.ParseMessage(frame[11:], "split"); err != nil {
|
||||
t.Fatalf("unexpected error on payload completion: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case data := <-que.RestoreChan():
|
||||
if data.ID != 0 {
|
||||
t.Fatalf("expected deprecated frame ID to stay zero, got %d", data.ID)
|
||||
}
|
||||
if data.Conn != "split" {
|
||||
t.Fatalf("unexpected conn: %#v", data.Conn)
|
||||
}
|
||||
if !bytes.Equal(data.Msg, []byte("hello")) {
|
||||
t.Fatalf("unexpected payload: %q", data.Msg)
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("did not restore parsed frame")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueueParseMessageViewSplitAcrossCalls(t *testing.T) {
|
||||
que := NewQueue()
|
||||
frame := que.BuildMessage([]byte("hello"))
|
||||
var got [][]byte
|
||||
|
||||
handler := func(view FrameView) error {
|
||||
got = append(got, cloneBytes(view.Payload))
|
||||
if view.Conn != "split-view" {
|
||||
t.Fatalf("unexpected conn: %#v", view.Conn)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := que.ParseMessageView(frame[:3], "split-view", handler); err != nil {
|
||||
t.Fatalf("unexpected error on partial magic: %v", err)
|
||||
}
|
||||
if err := que.ParseMessageView(frame[3:11], "split-view", handler); err != nil {
|
||||
t.Fatalf("unexpected error on partial header: %v", err)
|
||||
}
|
||||
if err := que.ParseMessageView(frame[11:], "split-view", handler); err != nil {
|
||||
t.Fatalf("unexpected error on payload completion: %v", err)
|
||||
}
|
||||
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("unexpected frame count: got %d want 1", len(got))
|
||||
}
|
||||
if !bytes.Equal(got[0], []byte("hello")) {
|
||||
t.Fatalf("unexpected payload: %q", got[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueueParseMessageOwnedSplitAcrossCalls(t *testing.T) {
|
||||
que := NewQueue()
|
||||
frame := que.BuildMessage([]byte("hello"))
|
||||
var got []MsgQueue
|
||||
|
||||
handler := func(msg MsgQueue) error {
|
||||
got = append(got, MsgQueue{
|
||||
Msg: cloneBytes(msg.Msg),
|
||||
Conn: msg.Conn,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := que.ParseMessageOwned(frame[:3], "split-owned", handler); err != nil {
|
||||
t.Fatalf("unexpected error on partial magic: %v", err)
|
||||
}
|
||||
if err := que.ParseMessageOwned(frame[3:11], "split-owned", handler); err != nil {
|
||||
t.Fatalf("unexpected error on partial header: %v", err)
|
||||
}
|
||||
if err := que.ParseMessageOwned(frame[11:], "split-owned", handler); err != nil {
|
||||
t.Fatalf("unexpected error on payload completion: %v", err)
|
||||
}
|
||||
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("unexpected frame count: got %d want 1", len(got))
|
||||
}
|
||||
if got[0].Conn != "split-owned" {
|
||||
t.Fatalf("unexpected conn: %#v", got[0].Conn)
|
||||
}
|
||||
if !bytes.Equal(got[0].Msg, []byte("hello")) {
|
||||
t.Fatalf("unexpected payload: %q", got[0].Msg)
|
||||
}
|
||||
select {
|
||||
case msg := <-que.RestoreChan():
|
||||
t.Fatalf("ParseMessageOwned should not use RestoreChan, got %#v", msg)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueueParseMessageSkipsGarbagePrefix(t *testing.T) {
|
||||
que := NewQueueWithCount(1)
|
||||
frame := que.BuildMessage([]byte("hello"))
|
||||
|
||||
err := que.ParseMessage(append([]byte("junk"), frame...), "garbage")
|
||||
if !errors.Is(err, ErrQueueDataFormat) {
|
||||
t.Fatalf("expected data format error, got %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case data := <-que.RestoreChan():
|
||||
if !bytes.Equal(data.Msg, []byte("hello")) {
|
||||
t.Fatalf("unexpected payload after resync: %q", data.Msg)
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("did not restore frame after skipping garbage")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueueParseMessageViewSkipsGarbagePrefix(t *testing.T) {
|
||||
que := NewQueue()
|
||||
frame := que.BuildMessage([]byte("hello"))
|
||||
var got []byte
|
||||
|
||||
err := que.ParseMessageView(append([]byte("junk"), frame...), "garbage-view", func(view FrameView) error {
|
||||
got = cloneBytes(view.Payload)
|
||||
return nil
|
||||
})
|
||||
if !errors.Is(err, ErrQueueDataFormat) {
|
||||
t.Fatalf("expected data format error, got %v", err)
|
||||
}
|
||||
if !bytes.Equal(got, []byte("hello")) {
|
||||
t.Fatalf("unexpected payload after resync: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueueParseMessageViewNilHandler(t *testing.T) {
|
||||
que := NewQueue()
|
||||
err := que.ParseMessageView(que.BuildMessage([]byte("hello")), "nil-handler", nil)
|
||||
if !errors.Is(err, ErrQueueFrameHandlerNil) {
|
||||
t.Fatalf("ParseMessageView error = %v, want %v", err, ErrQueueFrameHandlerNil)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueueParseMessageOwnedNilHandler(t *testing.T) {
|
||||
que := NewQueue()
|
||||
err := que.ParseMessageOwned(que.BuildMessage([]byte("hello")), "nil-handler-owned", nil)
|
||||
if !errors.Is(err, ErrQueueFrameHandlerNil) {
|
||||
t.Fatalf("ParseMessageOwned error = %v, want %v", err, ErrQueueFrameHandlerNil)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueueWriteFrameNilWriter(t *testing.T) {
|
||||
que := NewQueue()
|
||||
err := que.WriteFrame(nil, []byte("hello"))
|
||||
if !errors.Is(err, io.ErrClosedPipe) {
|
||||
t.Fatalf("WriteFrame error = %v, want %v", err, io.ErrClosedPipe)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueueWriteFrameBuffersNilWriter(t *testing.T) {
|
||||
que := NewQueue()
|
||||
err := que.WriteFrameBuffers(nil, []byte("hello"))
|
||||
if !errors.Is(err, io.ErrClosedPipe) {
|
||||
t.Fatalf("WriteFrameBuffers error = %v, want %v", err, io.ErrClosedPipe)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueueWriteFramesBuffersNilWriter(t *testing.T) {
|
||||
que := NewQueue()
|
||||
err := que.WriteFramesBuffers(nil, []byte("hello"))
|
||||
if !errors.Is(err, io.ErrClosedPipe) {
|
||||
t.Fatalf("WriteFramesBuffers error = %v, want %v", err, io.ErrClosedPipe)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueueParseMessageRejectsUnsupportedVersion(t *testing.T) {
|
||||
que := NewQueueWithCount(1)
|
||||
frame := que.BuildMessage([]byte("hello"))
|
||||
frame[12] = 2
|
||||
|
||||
err := que.ParseMessage(frame, "version")
|
||||
if !errors.Is(err, ErrQueueUnsupportedVersion) {
|
||||
t.Fatalf("expected unsupported version error, got %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case data := <-que.RestoreChan():
|
||||
t.Fatalf("unexpected restored frame: %#v", data)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueueParseMessageRejectsMessageTooLarge(t *testing.T) {
|
||||
que := NewQueueCtx(nil, 1, 4)
|
||||
frame := que.BuildMessage([]byte("hello"))
|
||||
|
||||
err := que.ParseMessage(frame, "large")
|
||||
if !errors.Is(err, ErrQueueMessageTooLarge) {
|
||||
t.Fatalf("expected message too large error, got %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case data := <-que.RestoreChan():
|
||||
t.Fatalf("unexpected restored frame: %#v", data)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueueParseMessageRejectsInvalidConnKey(t *testing.T) {
|
||||
que := NewQueue()
|
||||
frame := que.BuildMessage([]byte("hello"))
|
||||
err := que.ParseMessage(frame, []byte("not-comparable"))
|
||||
if !errors.Is(err, ErrQueueConnKeyInvalid) {
|
||||
t.Fatalf("expected invalid conn key error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueueParseMessageRejectsNilConnKey(t *testing.T) {
|
||||
que := NewQueue()
|
||||
frame := que.BuildMessage([]byte("hello"))
|
||||
err := que.ParseMessage(frame, nil)
|
||||
if !errors.Is(err, ErrQueueConnKeyNil) {
|
||||
t.Fatalf("expected nil conn key error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueuePayloadSizeToUint32RejectsOverflow(t *testing.T) {
|
||||
_, err := payloadSizeToUint32(uint64(^uint32(0)) + 1)
|
||||
if !errors.Is(err, ErrQueueMessageTooLarge) {
|
||||
t.Fatalf("expected message too large error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueueRestoreDurationZeroWaitsUntilMessage(t *testing.T) {
|
||||
que := NewQueueWithCount(1)
|
||||
que.RestoreDuration(0)
|
||||
|
||||
type restoreResult struct {
|
||||
msg MsgQueue
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan restoreResult, 1)
|
||||
go func() {
|
||||
msg, err := que.Restore()
|
||||
resultCh <- restoreResult{msg: msg, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
t.Fatalf("Restore returned too early: %#v", result)
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
|
||||
if err := que.ParseMessage(que.BuildMessage([]byte("hello")), "forever"); err != nil {
|
||||
t.Fatalf("ParseMessage failed: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
if result.err != nil {
|
||||
t.Fatalf("Restore returned error: %v", result.err)
|
||||
}
|
||||
if result.msg.Conn != "forever" || !bytes.Equal(result.msg.Msg, []byte("hello")) {
|
||||
t.Fatalf("unexpected restore result: %#v", result.msg)
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("Restore did not return after message arrival")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueueRestoreReturnsContextErrorOnStop(t *testing.T) {
|
||||
que := NewQueue()
|
||||
resultCh := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := que.Restore()
|
||||
resultCh <- err
|
||||
}()
|
||||
|
||||
que.Stop()
|
||||
|
||||
select {
|
||||
case err := <-resultCh:
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("expected context.Canceled, got %v", err)
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("Restore did not return after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueueRestoreChanClosesOnStop(t *testing.T) {
|
||||
que := NewQueue()
|
||||
resultCh := make(chan bool, 1)
|
||||
go func() {
|
||||
_, ok := <-que.RestoreChan()
|
||||
resultCh <- ok
|
||||
}()
|
||||
|
||||
que.Stop()
|
||||
|
||||
select {
|
||||
case ok := <-resultCh:
|
||||
if ok {
|
||||
t.Fatal("expected RestoreChan to close after Stop")
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("RestoreChan did not close after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueueRestoreChanClosesOnContextCancel(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
que := NewQueueCtx(ctx, 1, 0)
|
||||
resultCh := make(chan bool, 1)
|
||||
go func() {
|
||||
_, ok := <-que.RestoreChan()
|
||||
resultCh <- ok
|
||||
}()
|
||||
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case ok := <-resultCh:
|
||||
if ok {
|
||||
t.Fatal("expected RestoreChan to close after context cancel")
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("RestoreChan did not close after context cancel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueueParseMessageReturnsContextErrorWhenStoppedWhilePoolIsFull(t *testing.T) {
|
||||
que := NewQueueWithCount(1)
|
||||
if err := que.ParseMessage(que.BuildMessage([]byte("first")), "full"); err != nil {
|
||||
t.Fatalf("ParseMessage first failed: %v", err)
|
||||
}
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- que.ParseMessage(que.BuildMessage([]byte("second")), "full")
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
t.Fatalf("ParseMessage returned before Stop: %v", err)
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
|
||||
que.Stop()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("expected context.Canceled, got %v", err)
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("ParseMessage did not return after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueueParseMessageViewAllowsReentrantParseOnSameConn(t *testing.T) {
|
||||
que := NewQueue()
|
||||
reentered := false
|
||||
|
||||
err := que.ParseMessageView(que.BuildMessage([]byte("outer")), "reentrant", func(view FrameView) error {
|
||||
if !bytes.Equal(view.Payload, []byte("outer")) {
|
||||
t.Fatalf("unexpected outer payload: %q", view.Payload)
|
||||
}
|
||||
return que.ParseMessageView(que.BuildMessage([]byte("inner")), "reentrant", func(inner FrameView) error {
|
||||
reentered = true
|
||||
if !bytes.Equal(inner.Payload, []byte("inner")) {
|
||||
t.Fatalf("unexpected inner payload: %q", inner.Payload)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ParseMessageView failed: %v", err)
|
||||
}
|
||||
if !reentered {
|
||||
t.Fatal("expected reentrant ParseMessageView to run")
|
||||
}
|
||||
}
|
||||
|
||||
101
que_types.go
Normal file
101
que_types.go
Normal file
@ -0,0 +1,101 @@
|
||||
package stario
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ErrDeadlineExceeded = errors.New("deadline exceeded")
|
||||
var ErrQueueDataFormat = errors.New("data format error")
|
||||
var ErrQueueUnsupportedVersion = errors.New("unsupported frame version")
|
||||
var ErrQueueUnsupportedFlags = errors.New("unsupported frame flags")
|
||||
var ErrQueueMessageTooLarge = errors.New("message too large")
|
||||
var ErrQueueFrameHandlerNil = errors.New("frame handler is nil")
|
||||
var ErrQueueConnKeyInvalid = errors.New("queue conn key must be comparable")
|
||||
var ErrQueueConnKeyNil = errors.New("queue conn key must not be nil")
|
||||
|
||||
const (
|
||||
queMagicSize = 8
|
||||
queHeaderSize = 14
|
||||
queVersionV1 = 1
|
||||
queSupportedFlags = 0
|
||||
)
|
||||
|
||||
// 识别头
|
||||
var queMagic = []byte{11, 27, 19, 96, 12, 25, 02, 20}
|
||||
|
||||
// MsgQueue 为基本的信息单位。
|
||||
type MsgQueue struct {
|
||||
// Deprecated: frame-level IDs are no longer emitted by StarQueue v2 framing.
|
||||
ID uint16
|
||||
|
||||
Msg []byte
|
||||
Conn interface{}
|
||||
}
|
||||
|
||||
// FrameView exposes a parsed payload without forcing StarQueue to clone it.
|
||||
//
|
||||
// The payload is only valid during the ParseMessageView callback. Callers must
|
||||
// copy it if they need to keep it after the callback returns.
|
||||
type FrameView struct {
|
||||
Payload []byte
|
||||
Conn interface{}
|
||||
}
|
||||
|
||||
// StarQueue 为流数据中的消息队列分发。
|
||||
type StarQueue struct {
|
||||
maxLength uint32
|
||||
msgPool chan MsgQueue
|
||||
states sync.Map
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
duration time.Duration
|
||||
sendMu sync.RWMutex
|
||||
stopOnce sync.Once
|
||||
|
||||
// Deprecated: new code should keep StarQueue focused on framing only.
|
||||
Encode bool
|
||||
// Deprecated: new code should keep StarQueue focused on framing only.
|
||||
EncodeFunc func([]byte) []byte
|
||||
// Deprecated: new code should keep StarQueue focused on framing only.
|
||||
DecodeFunc func([]byte) []byte
|
||||
}
|
||||
|
||||
type queHeader struct {
|
||||
Length uint32
|
||||
Version uint8
|
||||
Flags uint8
|
||||
}
|
||||
|
||||
type queConnState struct {
|
||||
mu sync.Mutex
|
||||
buf []byte
|
||||
}
|
||||
|
||||
func NewQueueCtx(ctx context.Context, count int64, maxMsgLength uint32) *StarQueue {
|
||||
if count < 0 {
|
||||
panic("stario: negative queue count")
|
||||
}
|
||||
q := &StarQueue{
|
||||
maxLength: maxMsgLength,
|
||||
msgPool: make(chan MsgQueue, count),
|
||||
}
|
||||
if ctx == nil {
|
||||
q.ctx, q.cancel = context.WithCancel(context.Background())
|
||||
} else {
|
||||
q.ctx, q.cancel = context.WithCancel(ctx)
|
||||
}
|
||||
context.AfterFunc(q.ctx, q.shutdown)
|
||||
return q
|
||||
}
|
||||
|
||||
func NewQueueWithCount(count int64) *StarQueue {
|
||||
return NewQueueCtx(nil, count, 0)
|
||||
}
|
||||
|
||||
// NewQueue 建立一个新消息队列。
|
||||
func NewQueue() *StarQueue {
|
||||
return NewQueueWithCount(32)
|
||||
}
|
||||
44
signal_error.go
Normal file
44
signal_error.go
Normal file
@ -0,0 +1,44 @@
|
||||
package stario
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrSignalInterrupt = errors.New("interrupt")
|
||||
ErrSignalStop = errors.New("stop")
|
||||
ErrSignalQuit = errors.New("quit")
|
||||
)
|
||||
|
||||
type inputSignalError struct {
|
||||
msg string
|
||||
cause error
|
||||
}
|
||||
|
||||
func (e *inputSignalError) Error() string {
|
||||
return e.msg
|
||||
}
|
||||
|
||||
func (e *inputSignalError) Unwrap() error {
|
||||
return e.cause
|
||||
}
|
||||
|
||||
func signalErrorForType(sigtype rune) error {
|
||||
switch sigtype {
|
||||
case 0x03:
|
||||
return &inputSignalError{
|
||||
msg: "SIGNAL SIGINT RECIVED",
|
||||
cause: ErrSignalInterrupt,
|
||||
}
|
||||
case 0x1a:
|
||||
return &inputSignalError{
|
||||
msg: "SIGNAL SIGSTOP RECIVED",
|
||||
cause: ErrSignalStop,
|
||||
}
|
||||
case 0x1c:
|
||||
return &inputSignalError{
|
||||
msg: "SIGNAL SIGQUIT RECIVED",
|
||||
cause: ErrSignalQuit,
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
50
signal_error_test.go
Normal file
50
signal_error_test.go
Normal file
@ -0,0 +1,50 @@
|
||||
package stario
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSignalErrorForType(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
sig rune
|
||||
msg string
|
||||
want error
|
||||
}{
|
||||
{name: "interrupt", sig: 0x03, msg: "SIGNAL SIGINT RECIVED", want: ErrSignalInterrupt},
|
||||
{name: "stop", sig: 0x1a, msg: "SIGNAL SIGSTOP RECIVED", want: ErrSignalStop},
|
||||
{name: "quit", sig: 0x1c, msg: "SIGNAL SIGQUIT RECIVED", want: ErrSignalQuit},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
err := signalErrorForType(tc.sig)
|
||||
if err == nil {
|
||||
t.Fatalf("%s: expected non-nil error", tc.name)
|
||||
}
|
||||
if err.Error() != tc.msg {
|
||||
t.Fatalf("%s: unexpected error text: got %q want %q", tc.name, err.Error(), tc.msg)
|
||||
}
|
||||
if !errors.Is(err, tc.want) {
|
||||
t.Fatalf("%s: errors.Is mismatch: got %v want %v", tc.name, err, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignalErrorForUnknownType(t *testing.T) {
|
||||
if err := signalErrorForType('x'); err != nil {
|
||||
t.Fatalf("expected nil error for unknown signal, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignalReturnsTypedError(t *testing.T) {
|
||||
err := signal(0x03)
|
||||
if !errors.Is(err, ErrSignalInterrupt) {
|
||||
t.Fatalf("expected interrupt error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignalReturnsNilForUnknownType(t *testing.T) {
|
||||
if err := signal('x'); err != nil {
|
||||
t.Fatalf("expected nil error for unknown signal, got %v", err)
|
||||
}
|
||||
}
|
||||
@ -3,20 +3,6 @@
|
||||
|
||||
package stario
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
func signal(sigtype rune) error {
|
||||
//todo: use win32api call signal
|
||||
switch sigtype {
|
||||
case 0x03:
|
||||
return errors.New("SIGNAL SIGINT RECIVED")
|
||||
case 0x1a:
|
||||
return errors.New("SIGNAL SIGSTOP RECIVED")
|
||||
case 0x1c:
|
||||
return errors.New("SIGNAL SIGQUIT RECIVED")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
return signalErrorForType(sigtype)
|
||||
}
|
||||
|
||||
@ -3,24 +3,6 @@
|
||||
|
||||
package stario
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func signal(sigtype rune) error {
|
||||
switch sigtype {
|
||||
case 0x03:
|
||||
syscall.Kill(os.Getpid(), syscall.SIGINT)
|
||||
return errors.New("SIGNAL SIGINT RECIVED")
|
||||
case 0x1a:
|
||||
syscall.Kill(os.Getpid(), syscall.SIGSTOP)
|
||||
return errors.New("SIGNAL SIGSTOP RECIVED")
|
||||
case 0x1c:
|
||||
syscall.Kill(os.Getpid(), syscall.SIGQUIT)
|
||||
return errors.New("SIGNAL SIGQUIT RECIVED")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
return signalErrorForType(sigtype)
|
||||
}
|
||||
|
||||
123
starring.go
Normal file
123
starring.go
Normal file
@ -0,0 +1,123 @@
|
||||
package stario
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var ErrStarRingInvalidCapacity = errors.New("star ring capacity must be greater than zero")
|
||||
|
||||
// StarRing keeps the newest bytes in memory with fixed capacity.
|
||||
// It never blocks writers: when capacity is exceeded, oldest bytes are dropped.
|
||||
type StarRing struct {
|
||||
mu sync.RWMutex
|
||||
cap int
|
||||
|
||||
size int
|
||||
chunks list.List // each element: []byte
|
||||
}
|
||||
|
||||
func NewStarRing(capacity int) (*StarRing, error) {
|
||||
if capacity <= 0 {
|
||||
return nil, ErrStarRingInvalidCapacity
|
||||
}
|
||||
return &StarRing{cap: capacity}, nil
|
||||
}
|
||||
|
||||
func (s *StarRing) Capacity() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.cap
|
||||
}
|
||||
|
||||
func (s *StarRing) Len() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.size
|
||||
}
|
||||
|
||||
func (s *StarRing) Reset() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.size = 0
|
||||
s.chunks.Init()
|
||||
}
|
||||
|
||||
func (s *StarRing) Write(p []byte) (int, error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
payload := append([]byte(nil), p...)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.writeLocked(payload)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (s *StarRing) WriteString(text string) (int, error) {
|
||||
if len(text) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
return s.Write([]byte(text))
|
||||
}
|
||||
|
||||
func (s *StarRing) Snapshot() []byte {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
out := make([]byte, s.size)
|
||||
pos := 0
|
||||
for node := s.chunks.Front(); node != nil; node = node.Next() {
|
||||
chunk, ok := node.Value.([]byte)
|
||||
if !ok || len(chunk) == 0 {
|
||||
continue
|
||||
}
|
||||
pos += copy(out[pos:], chunk)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *StarRing) writeLocked(payload []byte) {
|
||||
if len(payload) >= s.cap {
|
||||
payload = payload[len(payload)-s.cap:]
|
||||
s.chunks.Init()
|
||||
s.chunks.PushBack(payload)
|
||||
s.size = len(payload)
|
||||
return
|
||||
}
|
||||
|
||||
s.chunks.PushBack(payload)
|
||||
s.size += len(payload)
|
||||
s.trimLocked()
|
||||
}
|
||||
|
||||
func (s *StarRing) trimLocked() {
|
||||
if s.size <= s.cap {
|
||||
return
|
||||
}
|
||||
overflow := s.size - s.cap
|
||||
for overflow > 0 {
|
||||
front := s.chunks.Front()
|
||||
if front == nil {
|
||||
s.size = 0
|
||||
return
|
||||
}
|
||||
head, ok := front.Value.([]byte)
|
||||
if !ok || len(head) == 0 {
|
||||
s.chunks.Remove(front)
|
||||
continue
|
||||
}
|
||||
if len(head) <= overflow {
|
||||
s.chunks.Remove(front)
|
||||
s.size -= len(head)
|
||||
overflow -= len(head)
|
||||
continue
|
||||
}
|
||||
trimmed := append([]byte(nil), head[overflow:]...)
|
||||
front.Value = trimmed
|
||||
s.size -= overflow
|
||||
overflow = 0
|
||||
}
|
||||
}
|
||||
104
starring_test.go
Normal file
104
starring_test.go
Normal file
@ -0,0 +1,104 @@
|
||||
package stario
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewStarRingRejectsInvalidCapacity(t *testing.T) {
|
||||
ring, err := NewStarRing(0)
|
||||
if err != ErrStarRingInvalidCapacity {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if ring != nil {
|
||||
t.Fatal("expected nil ring when capacity is invalid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStarRingWriteAndSnapshot(t *testing.T) {
|
||||
ring, err := NewStarRing(16)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n, err := ring.Write([]byte("abc")); err != nil || n != 3 {
|
||||
t.Fatalf("write failed: n=%d err=%v", n, err)
|
||||
}
|
||||
if n, err := ring.WriteString("def"); err != nil || n != 3 {
|
||||
t.Fatalf("write string failed: n=%d err=%v", n, err)
|
||||
}
|
||||
|
||||
got := ring.Snapshot()
|
||||
if !bytes.Equal(got, []byte("abcdef")) {
|
||||
t.Fatalf("unexpected snapshot: %q", got)
|
||||
}
|
||||
if ring.Len() != 6 {
|
||||
t.Fatalf("unexpected ring len: %d", ring.Len())
|
||||
}
|
||||
if ring.Capacity() != 16 {
|
||||
t.Fatalf("unexpected ring capacity: %d", ring.Capacity())
|
||||
}
|
||||
}
|
||||
|
||||
func TestStarRingOverflowDropsOldestBytesPrecisely(t *testing.T) {
|
||||
ring, err := NewStarRing(10)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _ = ring.WriteString("abcde")
|
||||
_, _ = ring.WriteString("fghij")
|
||||
_, _ = ring.WriteString("kl")
|
||||
|
||||
got := ring.Snapshot()
|
||||
want := []byte("cdefghijkl")
|
||||
if !bytes.Equal(got, want) {
|
||||
t.Fatalf("unexpected snapshot after overflow: got %q want %q", got, want)
|
||||
}
|
||||
if ring.Len() != len(want) {
|
||||
t.Fatalf("unexpected ring len after overflow: %d", ring.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestStarRingLargeWriteKeepsTail(t *testing.T) {
|
||||
ring, err := NewStarRing(5)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _ = ring.WriteString("123")
|
||||
_, _ = ring.WriteString("456789")
|
||||
|
||||
got := ring.Snapshot()
|
||||
want := []byte("56789")
|
||||
if !bytes.Equal(got, want) {
|
||||
t.Fatalf("unexpected snapshot after large write: got %q want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStarRingSnapshotIsCopy(t *testing.T) {
|
||||
ring, err := NewStarRing(8)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _ = ring.WriteString("hello")
|
||||
|
||||
first := ring.Snapshot()
|
||||
first[0] = 'H'
|
||||
second := ring.Snapshot()
|
||||
if !bytes.Equal(second, []byte("hello")) {
|
||||
t.Fatalf("snapshot should not expose internal storage: %q", second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStarRingReset(t *testing.T) {
|
||||
ring, err := NewStarRing(8)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _ = ring.WriteString("hello")
|
||||
ring.Reset()
|
||||
if ring.Len() != 0 {
|
||||
t.Fatalf("expected len=0 after reset, got %d", ring.Len())
|
||||
}
|
||||
if got := ring.Snapshot(); len(got) != 0 {
|
||||
t.Fatalf("expected empty snapshot after reset, got %q", got)
|
||||
}
|
||||
}
|
||||
17
sync.go
17
sync.go
@ -12,6 +12,10 @@ const (
|
||||
waitGroupAddModeLoose
|
||||
)
|
||||
|
||||
// WaitGroup is a concurrency-limited sync.WaitGroup variant.
|
||||
//
|
||||
// A zero or negative limit means unlimited concurrency. WaitGroup must not be
|
||||
// copied after first use.
|
||||
type WaitGroup struct {
|
||||
wg sync.WaitGroup
|
||||
mu sync.Mutex
|
||||
@ -22,6 +26,7 @@ type WaitGroup struct {
|
||||
addMode waitGroupAddMode
|
||||
}
|
||||
|
||||
// NewWaitGroup creates a WaitGroup with the provided concurrency limit.
|
||||
func NewWaitGroup(maxCount int) WaitGroup {
|
||||
if maxCount < 0 {
|
||||
panic("stario: negative max wait count")
|
||||
@ -38,6 +43,10 @@ func (w *WaitGroup) init() {
|
||||
})
|
||||
}
|
||||
|
||||
// Add adjusts the running task count.
|
||||
//
|
||||
// Positive deltas may block when the concurrency limit is reached. Negative
|
||||
// deltas release running slots.
|
||||
func (w *WaitGroup) Add(delta int) {
|
||||
w.init()
|
||||
if delta == 0 {
|
||||
@ -87,10 +96,12 @@ func (w *WaitGroup) release(delta int) {
|
||||
w.cond.Broadcast()
|
||||
}
|
||||
|
||||
// Done releases one running task slot.
|
||||
func (w *WaitGroup) Done() {
|
||||
w.Add(-1)
|
||||
}
|
||||
|
||||
// Go runs fn in a goroutine while accounting for the concurrency limit.
|
||||
func (w *WaitGroup) Go(fn func()) {
|
||||
w.Add(1)
|
||||
go func() {
|
||||
@ -99,11 +110,13 @@ func (w *WaitGroup) Go(fn func()) {
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait blocks until all added work has completed.
|
||||
func (w *WaitGroup) Wait() {
|
||||
w.init()
|
||||
w.wg.Wait()
|
||||
}
|
||||
|
||||
// GetMaxWaitNum returns the current concurrency limit.
|
||||
func (w *WaitGroup) GetMaxWaitNum() int {
|
||||
w.init()
|
||||
w.mu.Lock()
|
||||
@ -111,6 +124,7 @@ func (w *WaitGroup) GetMaxWaitNum() int {
|
||||
return w.maxCount
|
||||
}
|
||||
|
||||
// SetMaxWaitNum updates the concurrency limit.
|
||||
func (w *WaitGroup) SetMaxWaitNum(num int) {
|
||||
if num < 0 {
|
||||
panic("stario: negative max wait count")
|
||||
@ -122,6 +136,8 @@ func (w *WaitGroup) SetMaxWaitNum(num int) {
|
||||
w.cond.Broadcast()
|
||||
}
|
||||
|
||||
// SetStrictAddMode controls whether Add(n>1) panics or auto-expands the limit
|
||||
// when the requested batch exceeds the current capacity.
|
||||
func (w *WaitGroup) SetStrictAddMode(strict bool) {
|
||||
w.init()
|
||||
w.mu.Lock()
|
||||
@ -134,6 +150,7 @@ func (w *WaitGroup) SetStrictAddMode(strict bool) {
|
||||
w.cond.Broadcast()
|
||||
}
|
||||
|
||||
// StrictAddMode reports whether strict batch-add behavior is enabled.
|
||||
func (w *WaitGroup) StrictAddMode() bool {
|
||||
w.init()
|
||||
w.mu.Lock()
|
||||
|
||||
73
terminal.go
Normal file
73
terminal.go
Normal file
@ -0,0 +1,73 @@
|
||||
package stario
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// ErrTerminalNotTTY reports that terminal-only input was requested from a
|
||||
// non-terminal stdin.
|
||||
var ErrTerminalNotTTY = errors.New("terminal input requires a tty")
|
||||
|
||||
var terminalCheckFunc = term.IsTerminal
|
||||
|
||||
// IsTerminal reports whether os.Stdin is attached to a terminal.
|
||||
func IsTerminal() bool {
|
||||
return IsTerminalFile(os.Stdin)
|
||||
}
|
||||
|
||||
// IsTerminalFD reports whether fd is attached to a terminal.
|
||||
func IsTerminalFD(fd int) bool {
|
||||
return terminalCheckFunc(fd)
|
||||
}
|
||||
|
||||
// IsTerminalFile reports whether file is attached to a terminal.
|
||||
func IsTerminalFile(file *os.File) bool {
|
||||
if file == nil {
|
||||
return false
|
||||
}
|
||||
return IsTerminalFD(int(file.Fd()))
|
||||
}
|
||||
|
||||
// ReadPasswordContext reads one password-style line from the terminal.
|
||||
//
|
||||
// It returns ErrTerminalNotTTY when stdin is not a terminal. If ctx is canceled
|
||||
// while waiting for input, this function returns ctx.Err() without waiting for
|
||||
// the underlying terminal read to finish.
|
||||
func ReadPasswordContext(ctx context.Context, hint string, defaultVal string) (string, error) {
|
||||
return ReadPasswordContextWithMask(ctx, hint, defaultVal, "")
|
||||
}
|
||||
|
||||
// ReadPasswordContextWithMask is like ReadPasswordContext but echoes the given
|
||||
// mask string while the user types.
|
||||
func ReadPasswordContextWithMask(ctx context.Context, hint string, defaultVal string, mask string) (string, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !IsTerminal() {
|
||||
return "", ErrTerminalNotTTY
|
||||
}
|
||||
session, err := rawTerminalSessionFactory(hint, true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
resultCh := make(chan InputMsg, 1)
|
||||
go func() {
|
||||
defer session.Close()
|
||||
resultCh <- rawLineInputSession(session, defaultVal, mask, rawInputSignalReturnError)
|
||||
}()
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
return result.String()
|
||||
case <-ctx.Done():
|
||||
session.Abort()
|
||||
<-resultCh
|
||||
return "", ctx.Err()
|
||||
}
|
||||
}
|
||||
77
terminal_test.go
Normal file
77
terminal_test.go
Normal file
@ -0,0 +1,77 @@
|
||||
package stario
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func installTerminalCheckStub(t *testing.T, ok bool) {
|
||||
t.Helper()
|
||||
prev := terminalCheckFunc
|
||||
terminalCheckFunc = func(fd int) bool {
|
||||
return ok
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
terminalCheckFunc = prev
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsTerminalFDUsesStub(t *testing.T) {
|
||||
installTerminalCheckStub(t, true)
|
||||
if !IsTerminalFD(1) {
|
||||
t.Fatal("expected terminal check to return true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadPasswordContextRejectsNonTTY(t *testing.T) {
|
||||
installTerminalCheckStub(t, false)
|
||||
_, err := ReadPasswordContext(context.Background(), "pwd: ", "")
|
||||
if !errors.Is(err, ErrTerminalNotTTY) {
|
||||
t.Fatalf("expected ErrTerminalNotTTY, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadPasswordContextReadsValue(t *testing.T) {
|
||||
installTerminalCheckStub(t, true)
|
||||
installRawInputStub(t, "secret\n", nil)
|
||||
|
||||
got, err := ReadPasswordContext(context.Background(), "pwd: ", "")
|
||||
if err != nil {
|
||||
t.Fatalf("ReadPasswordContext returned error: %v", err)
|
||||
}
|
||||
if got != "secret" {
|
||||
t.Fatalf("unexpected password: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadPasswordContextCanceledWhileWaiting(t *testing.T) {
|
||||
installTerminalCheckStub(t, true)
|
||||
|
||||
prevFactory := rawTerminalSessionFactory
|
||||
pipeReader, pipeWriter := io.Pipe()
|
||||
rawTerminalSessionFactory = func(hint string, printNewline bool) (*rawTerminalSession, error) {
|
||||
return &rawTerminalSession{
|
||||
input: pipeReader,
|
||||
reader: bufio.NewReader(pipeReader),
|
||||
redrawHint: strings.TrimSpace(hint),
|
||||
}, nil
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
rawTerminalSessionFactory = prevFactory
|
||||
_ = pipeWriter.Close()
|
||||
_ = pipeReader.Close()
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err := ReadPasswordContext(ctx, "pwd: ", "")
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Fatalf("expected deadline exceeded, got %v", err)
|
||||
}
|
||||
}
|
||||
10
tty_other.go
Normal file
10
tty_other.go
Normal file
@ -0,0 +1,10 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package stario
|
||||
|
||||
import "os"
|
||||
|
||||
func openRawTerminalInput() (*os.File, error) {
|
||||
return os.OpenFile("/dev/tty", os.O_RDWR, 0)
|
||||
}
|
||||
10
tty_windows.go
Normal file
10
tty_windows.go
Normal file
@ -0,0 +1,10 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package stario
|
||||
|
||||
import "os"
|
||||
|
||||
func openRawTerminalInput() (*os.File, error) {
|
||||
return os.OpenFile("CONIN$", os.O_RDWR, 0)
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user