stario: 提升 Go 1.20 基线与交互/队列稳定性
- 提升 go.mod 基线到 Go 1.20,并补齐对应测试 - 修正 Passwd / PasswdResponseSignal 语义,Ctrl+C 默认退出当前流程 - 优化 raw terminal redraw、Restore 与 StopUntil 的边界行为 - 新增 StarPipe、FrameReader/FrameWriter、ReadFullContext/WriteFullContext/CopyContext、IsTerminal/ReadPasswordContext - 收口 StarQueue / StarBuffer 语义,删除 EndWrite,统一 Close / Abort 行为 - 补齐 signal、timeout、queue、terminal、pipe、buffer 的回归测试与 race 覆盖
This commit is contained in:
parent
3add9183b3
commit
c8facb5a03
43
circle.go
43
circle.go
@ -10,6 +10,10 @@ var ErrStarBufferInvalidCapacity = errors.New("star buffer capacity must be grea
|
|||||||
var ErrStarBufferClosed = errors.New("star buffer closed")
|
var ErrStarBufferClosed = errors.New("star buffer closed")
|
||||||
var ErrStarBufferWriteClosed = errors.New("star buffer write closed")
|
var ErrStarBufferWriteClosed = errors.New("star buffer write closed")
|
||||||
|
|
||||||
|
// StarBuffer is a blocking ring buffer that implements stream-style reads and writes.
|
||||||
|
//
|
||||||
|
// Close marks the write side finished after all payload bytes are sent.
|
||||||
|
// Abort aborts both sides immediately but still allows buffered bytes to be drained.
|
||||||
type StarBuffer struct {
|
type StarBuffer struct {
|
||||||
datas []byte
|
datas []byte
|
||||||
pStart uint64
|
pStart uint64
|
||||||
@ -36,18 +40,21 @@ func NewStarBuffer(cap uint64) (*StarBuffer, error) {
|
|||||||
return rtnBuffer, nil
|
return rtnBuffer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Free returns the remaining writable capacity.
|
||||||
func (star *StarBuffer) Free() uint64 {
|
func (star *StarBuffer) Free() uint64 {
|
||||||
star.mu.Lock()
|
star.mu.Lock()
|
||||||
defer star.mu.Unlock()
|
defer star.mu.Unlock()
|
||||||
return star.cap - star.size
|
return star.cap - star.size
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cap returns the fixed buffer capacity.
|
||||||
func (star *StarBuffer) Cap() uint64 {
|
func (star *StarBuffer) Cap() uint64 {
|
||||||
star.mu.Lock()
|
star.mu.Lock()
|
||||||
defer star.mu.Unlock()
|
defer star.mu.Unlock()
|
||||||
return star.cap
|
return star.cap
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Len returns the currently buffered byte count.
|
||||||
func (star *StarBuffer) Len() uint64 {
|
func (star *StarBuffer) Len() uint64 {
|
||||||
star.mu.Lock()
|
star.mu.Lock()
|
||||||
defer star.mu.Unlock()
|
defer star.mu.Unlock()
|
||||||
@ -89,19 +96,21 @@ func (star *StarBuffer) putByte(data byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (star *StarBuffer) EndWrite() error {
|
// Close closes only the write side and satisfies the usual io.Closer-style
|
||||||
|
// "producer finished" semantics.
|
||||||
|
//
|
||||||
|
// Buffered bytes remain readable until drained; afterwards reads return io.EOF.
|
||||||
|
func (star *StarBuffer) Close() error {
|
||||||
star.mu.Lock()
|
star.mu.Lock()
|
||||||
defer star.mu.Unlock()
|
defer star.mu.Unlock()
|
||||||
if star.isClose {
|
return star.closeWriteLocked()
|
||||||
return ErrStarBufferClosed
|
|
||||||
}
|
|
||||||
star.isWriteEnd = true
|
|
||||||
star.notEmpty.Broadcast()
|
|
||||||
star.notFull.Broadcast()
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (star *StarBuffer) Close() error {
|
// Abort aborts the buffer and wakes blocked readers/writers immediately.
|
||||||
|
//
|
||||||
|
// Buffered bytes remain readable until drained; subsequent writes fail with
|
||||||
|
// ErrStarBufferClosed.
|
||||||
|
func (star *StarBuffer) Abort() error {
|
||||||
star.mu.Lock()
|
star.mu.Lock()
|
||||||
defer star.mu.Unlock()
|
defer star.mu.Unlock()
|
||||||
star.isClose = true
|
star.isClose = true
|
||||||
@ -111,10 +120,17 @@ func (star *StarBuffer) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (star *StarBuffer) Read(buf []byte) (int, error) {
|
func (star *StarBuffer) closeWriteLocked() error {
|
||||||
if buf == nil {
|
if star.isClose {
|
||||||
return 0, errors.New("buffer is nil")
|
return ErrStarBufferClosed
|
||||||
}
|
}
|
||||||
|
star.isWriteEnd = true
|
||||||
|
star.notEmpty.Broadcast()
|
||||||
|
star.notFull.Broadcast()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (star *StarBuffer) Read(buf []byte) (int, error) {
|
||||||
if len(buf) == 0 {
|
if len(buf) == 0 {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
@ -140,9 +156,6 @@ func (star *StarBuffer) Read(buf []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (star *StarBuffer) Write(bts []byte) (int, error) {
|
func (star *StarBuffer) Write(bts []byte) (int, error) {
|
||||||
if bts == nil {
|
|
||||||
return 0, star.EndWrite()
|
|
||||||
}
|
|
||||||
if len(bts) == 0 {
|
if len(bts) == 0 {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|||||||
44
circle_benchmark_test.go
Normal file
44
circle_benchmark_test.go
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
package stario
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkStarBufferByteRoundTrip(b *testing.B) {
|
||||||
|
buf, err := NewStarBuffer(4096)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.SetBytes(1)
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if err := buf.putByte('a'); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := buf.getByte(); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkStarBufferChunkRoundTrip(b *testing.B) {
|
||||||
|
buf, err := NewStarBuffer(8192)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
payload := []byte("hello world b612 hello world b612 b612 b612 b612 b612 b612")
|
||||||
|
scratch := make([]byte, len(payload))
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.SetBytes(int64(len(payload)))
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if _, err := buf.Write(payload); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := io.ReadFull(buf, scratch); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
103
circle_test.go
103
circle_test.go
@ -2,11 +2,8 @@ package stario
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewStarBufferRejectsZeroCapacity(t *testing.T) {
|
func TestNewStarBufferRejectsZeroCapacity(t *testing.T) {
|
||||||
@ -19,7 +16,7 @@ func TestNewStarBufferRejectsZeroCapacity(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStarBufferEndWriteDrainsThenEOF(t *testing.T) {
|
func TestStarBufferCloseDrainsThenEOF(t *testing.T) {
|
||||||
buf, err := NewStarBuffer(4)
|
buf, err := NewStarBuffer(4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@ -27,7 +24,7 @@ func TestStarBufferEndWriteDrainsThenEOF(t *testing.T) {
|
|||||||
if _, err := buf.Write([]byte("abcd")); err != nil {
|
if _, err := buf.Write([]byte("abcd")); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := buf.EndWrite(); err != nil {
|
if err := buf.Close(); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -46,11 +43,11 @@ func TestStarBufferEndWriteDrainsThenEOF(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if _, err := buf.Write([]byte("x")); err != ErrStarBufferWriteClosed {
|
if _, err := buf.Write([]byte("x")); err != ErrStarBufferWriteClosed {
|
||||||
t.Fatalf("unexpected write error after EndWrite: %v", err)
|
t.Fatalf("unexpected write error after Close: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStarBufferCloseAllowsDrain(t *testing.T) {
|
func TestStarBufferAbortAllowsDrain(t *testing.T) {
|
||||||
buf, err := NewStarBuffer(4)
|
buf, err := NewStarBuffer(4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@ -58,7 +55,7 @@ func TestStarBufferCloseAllowsDrain(t *testing.T) {
|
|||||||
if _, err := buf.Write([]byte("ab")); err != nil {
|
if _, err := buf.Write([]byte("ab")); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := buf.Close(); err != nil {
|
if err := buf.Abort(); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -68,102 +65,38 @@ func TestStarBufferCloseAllowsDrain(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if n != 2 || !bytes.Equal(got[:n], []byte("ab")) {
|
if n != 2 || !bytes.Equal(got[:n], []byte("ab")) {
|
||||||
t.Fatalf("unexpected payload after close: n=%d data=%q", n, got[:n])
|
t.Fatalf("unexpected payload after abort: n=%d data=%q", n, got[:n])
|
||||||
}
|
}
|
||||||
|
|
||||||
n, err = buf.Read(got)
|
n, err = buf.Read(got)
|
||||||
if n != 0 || err != io.EOF {
|
if n != 0 || err != io.EOF {
|
||||||
t.Fatalf("expected EOF after draining closed buffer, got n=%d err=%v", n, err)
|
t.Fatalf("expected EOF after draining aborted buffer, got n=%d err=%v", n, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := buf.Write([]byte("x")); err != ErrStarBufferClosed {
|
if _, err := buf.Write([]byte("x")); err != ErrStarBufferClosed {
|
||||||
t.Fatalf("unexpected write error after Close: %v", err)
|
t.Fatalf("unexpected write error after Abort: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_Circle(t *testing.T) {
|
func TestStarBufferNilReadIsNoOp(t *testing.T) {
|
||||||
buf, err := NewStarBuffer(2048)
|
buf, err := NewStarBuffer(4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
go func() {
|
if n, err := buf.Read(nil); n != 0 || err != nil {
|
||||||
for {
|
t.Fatalf("expected nil read to be a no-op, got n=%d err=%v", n, err)
|
||||||
//fmt.Println("write start")
|
|
||||||
buf.Write([]byte("中华人民共和国\n"))
|
|
||||||
//fmt.Println("write success")
|
|
||||||
time.Sleep(time.Millisecond * 50)
|
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
cpp := ""
|
|
||||||
go func() {
|
|
||||||
time.Sleep(time.Second * 3)
|
|
||||||
for {
|
|
||||||
cache := make([]byte, 64)
|
|
||||||
ints, err := buf.Read(cache)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println("read error", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if ints != 0 {
|
|
||||||
cpp += string(cache[:ints])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
time.Sleep(time.Second * 13)
|
|
||||||
fmt.Println(cpp)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_Circle_Speed(t *testing.T) {
|
func TestStarBufferNilWriteIsNoOp(t *testing.T) {
|
||||||
buf, err := NewStarBuffer(1048976)
|
buf, err := NewStarBuffer(4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
count := uint64(0)
|
if n, err := buf.Write(nil); n != 0 || err != nil {
|
||||||
for i := 1; i <= 10; i++ {
|
t.Fatalf("expected nil write to be a no-op, got n=%d err=%v", n, err)
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
buf.putByte('a')
|
|
||||||
}
|
}
|
||||||
}()
|
if _, err := buf.Write([]byte("ab")); err != nil {
|
||||||
}
|
t.Fatalf("nil write must not end the write side, got %v", err)
|
||||||
for i := 1; i <= 10; i++ {
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
_, err := buf.getByte()
|
|
||||||
if err == nil {
|
|
||||||
atomic.AddUint64(&count, 1)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
}
|
|
||||||
time.Sleep(time.Second * 10)
|
|
||||||
fmt.Println(count)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_Circle_Speed2(t *testing.T) {
|
|
||||||
buf, err := NewStarBuffer(8192)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
count := uint64(0)
|
|
||||||
for i := 1; i <= 10; i++ {
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
buf.Write([]byte("hello world b612 hello world b612 b612 b612 b612 b612 b612"))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
for i := 1; i <= 10; i++ {
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
mybuf := make([]byte, 1024)
|
|
||||||
j, err := buf.Read(mybuf)
|
|
||||||
if err == nil {
|
|
||||||
atomic.AddUint64(&count, uint64(j))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
time.Sleep(time.Second * 10)
|
|
||||||
fmt.Println(float64(count) / 10 / 1024 / 1024)
|
|
||||||
}
|
|
||||||
|
|||||||
126
fn.go
126
fn.go
@ -1,55 +1,111 @@
|
|||||||
package stario
|
package stario
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ERR_TIMEOUT is the legacy timeout sentinel used by WaitUntilTimeout*.
|
||||||
var ERR_TIMEOUT = errors.New("TIME OUT")
|
var ERR_TIMEOUT = errors.New("TIME OUT")
|
||||||
|
|
||||||
func WaitUntilTimeout(tm time.Duration, fn func(chan struct{}) error) error {
|
// WaitUntilContext runs fn and returns either its result or the context error,
|
||||||
var err error
|
// whichever happens first.
|
||||||
finished := make(chan struct{})
|
func WaitUntilContext(ctx context.Context, fn func(context.Context) error) error {
|
||||||
imout := make(chan struct{})
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
finished := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
err = fn(imout)
|
finished <- fn(ctx)
|
||||||
finished <- struct{}{}
|
|
||||||
}()
|
}()
|
||||||
select {
|
select {
|
||||||
case <-finished:
|
case err := <-finished:
|
||||||
return err
|
return err
|
||||||
case <-time.After(tm):
|
case <-ctx.Done():
|
||||||
close(imout)
|
return ctx.Err()
|
||||||
return ERR_TIMEOUT
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WaitUntilFinished(fn func() error) <-chan error {
|
// WaitUntilContextFinished is the asynchronous form of WaitUntilContext.
|
||||||
finished := make(chan error)
|
func WaitUntilContextFinished(ctx context.Context, fn func(context.Context) error) <-chan error {
|
||||||
|
result := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
err := fn()
|
result <- WaitUntilContext(ctx, fn)
|
||||||
finished <- err
|
close(result)
|
||||||
}()
|
|
||||||
return finished
|
|
||||||
}
|
|
||||||
|
|
||||||
func WaitUntilTimeoutFinished(tm time.Duration, fn func(chan struct{}) error) <-chan error {
|
|
||||||
var err error
|
|
||||||
finished := make(chan struct{})
|
|
||||||
result := make(chan error)
|
|
||||||
imout := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
err = fn(imout)
|
|
||||||
finished <- struct{}{}
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
select {
|
|
||||||
case <-finished:
|
|
||||||
result <- err
|
|
||||||
case <-time.After(tm):
|
|
||||||
close(imout)
|
|
||||||
result <- ERR_TIMEOUT
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WaitUntilContextDone adapts a done-channel worker to a context-based wait.
|
||||||
|
func WaitUntilContextDone(ctx context.Context, fn func(<-chan struct{}) error) error {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
return WaitUntilContext(ctx, func(context.Context) error {
|
||||||
|
return fn(ctx.Done())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitUntilContextDoneFinished is the asynchronous form of WaitUntilContextDone.
|
||||||
|
func WaitUntilContextDoneFinished(ctx context.Context, fn func(<-chan struct{}) error) <-chan error {
|
||||||
|
return WaitUntilContextFinished(ctx, func(ctx context.Context) error {
|
||||||
|
return fn(ctx.Done())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitUntilTimeout is a legacy timeout helper kept for compatibility.
|
||||||
|
//
|
||||||
|
// The provided stop channel must be treated as receive-only by callers. New
|
||||||
|
// code should prefer WaitUntilContext or WaitUntilContextDone.
|
||||||
|
func WaitUntilTimeout(tm time.Duration, fn func(chan struct{}) error) error {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), tm)
|
||||||
|
defer cancel()
|
||||||
|
err := WaitUntilContextDone(ctx, func(done <-chan struct{}) error {
|
||||||
|
return fn(bridgeDoneChan(done))
|
||||||
|
})
|
||||||
|
if errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return ERR_TIMEOUT
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitUntilFinished runs fn asynchronously and returns its eventual result.
|
||||||
|
func WaitUntilFinished(fn func() error) <-chan error {
|
||||||
|
return WaitUntilContextFinished(context.Background(), func(ctx context.Context) error {
|
||||||
|
return fn()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitUntilTimeoutFinished is the asynchronous form of WaitUntilTimeout.
|
||||||
|
//
|
||||||
|
// The provided stop channel must be treated as receive-only by callers. New
|
||||||
|
// code should prefer WaitUntilContextFinished or WaitUntilContextDoneFinished.
|
||||||
|
func WaitUntilTimeoutFinished(tm time.Duration, fn func(chan struct{}) error) <-chan error {
|
||||||
|
result := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), tm)
|
||||||
|
defer cancel()
|
||||||
|
err := WaitUntilContextDone(ctx, func(done <-chan struct{}) error {
|
||||||
|
return fn(bridgeDoneChan(done))
|
||||||
|
})
|
||||||
|
if errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
err = ERR_TIMEOUT
|
||||||
|
}
|
||||||
|
result <- err
|
||||||
|
close(result)
|
||||||
|
}()
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func bridgeDoneChan(done <-chan struct{}) chan struct{} {
|
||||||
|
stop := make(chan struct{})
|
||||||
|
if done == nil {
|
||||||
|
return stop
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
<-done
|
||||||
|
close(stop)
|
||||||
|
}()
|
||||||
|
return stop
|
||||||
|
}
|
||||||
|
|||||||
174
fn_test.go
Normal file
174
fn_test.go
Normal file
@ -0,0 +1,174 @@
|
|||||||
|
package stario
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWaitUntilContextReturnsWorkerError(t *testing.T) {
|
||||||
|
want := errors.New("worker failed")
|
||||||
|
err := WaitUntilContext(context.Background(), func(ctx context.Context) error {
|
||||||
|
return want
|
||||||
|
})
|
||||||
|
if !errors.Is(err, want) {
|
||||||
|
t.Fatalf("unexpected error: got %v want %v", err, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWaitUntilContextReturnsDeadlineExceeded(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
workerReturned := make(chan struct{})
|
||||||
|
err := WaitUntilContext(ctx, func(ctx context.Context) error {
|
||||||
|
<-ctx.Done()
|
||||||
|
close(workerReturned)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if !errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
t.Fatalf("unexpected context error: %v", err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-workerReturned:
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Fatal("worker did not return after context deadline")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWaitUntilContextFinishedReturnsCanceled(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
result := WaitUntilContextFinished(ctx, func(ctx context.Context) error {
|
||||||
|
<-ctx.Done()
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
err, ok := <-result
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("result channel closed without value")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatalf("unexpected context error: %v", err)
|
||||||
|
}
|
||||||
|
if _, ok := <-result; ok {
|
||||||
|
t.Fatal("result channel should be closed after delivering the outcome")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWaitUntilContextDoneBridgesDoneChannel(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
doneClosed := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
err := WaitUntilContextDone(ctx, func(done <-chan struct{}) error {
|
||||||
|
<-done
|
||||||
|
close(doneClosed)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatalf("unexpected context error: %v", err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-doneClosed:
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Fatal("done channel was not closed when context canceled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWaitUntilContextDoneFinishedReturnsCanceled(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
doneClosed := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
result := WaitUntilContextDoneFinished(ctx, func(done <-chan struct{}) error {
|
||||||
|
<-done
|
||||||
|
close(doneClosed)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
err, ok := <-result
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("result channel closed without value")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatalf("unexpected context error: %v", err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-doneClosed:
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Fatal("done channel was not closed when context canceled")
|
||||||
|
}
|
||||||
|
if _, ok := <-result; ok {
|
||||||
|
t.Fatal("result channel should be closed after delivering the outcome")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWaitUntilTimeoutReturnsWorkerError(t *testing.T) {
|
||||||
|
want := errors.New("worker failed")
|
||||||
|
err := WaitUntilTimeout(time.Second, func(stop chan struct{}) error {
|
||||||
|
return want
|
||||||
|
})
|
||||||
|
if !errors.Is(err, want) {
|
||||||
|
t.Fatalf("unexpected error: got %v want %v", err, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWaitUntilTimeoutTimesOutWithoutBlockingWorkerReturn(t *testing.T) {
|
||||||
|
workerReturned := make(chan struct{})
|
||||||
|
err := WaitUntilTimeout(20*time.Millisecond, func(stop chan struct{}) error {
|
||||||
|
<-stop
|
||||||
|
close(workerReturned)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if !errors.Is(err, ERR_TIMEOUT) {
|
||||||
|
t.Fatalf("unexpected timeout error: %v", err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-workerReturned:
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Fatal("worker did not return after timeout signal")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWaitUntilFinishedReturnsWorkerError(t *testing.T) {
|
||||||
|
want := errors.New("worker failed")
|
||||||
|
err := <-WaitUntilFinished(func() error {
|
||||||
|
return want
|
||||||
|
})
|
||||||
|
if !errors.Is(err, want) {
|
||||||
|
t.Fatalf("unexpected error: got %v want %v", err, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWaitUntilTimeoutFinishedReturnsTimeout(t *testing.T) {
|
||||||
|
result := WaitUntilTimeoutFinished(20*time.Millisecond, func(stop chan struct{}) error {
|
||||||
|
<-stop
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
err, ok := <-result
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("result channel closed without value")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, ERR_TIMEOUT) {
|
||||||
|
t.Fatalf("unexpected timeout error: %v", err)
|
||||||
|
}
|
||||||
|
if _, ok := <-result; ok {
|
||||||
|
t.Fatal("result channel should be closed after delivering the outcome")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWaitUntilTimeoutFinishedReturnsWorkerError(t *testing.T) {
|
||||||
|
want := errors.New("worker failed")
|
||||||
|
err, ok := <-WaitUntilTimeoutFinished(time.Second, func(stop chan struct{}) error {
|
||||||
|
return want
|
||||||
|
})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("result channel closed without value")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, want) {
|
||||||
|
t.Fatalf("unexpected error: got %v want %v", err, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
191
frameio.go
Normal file
191
frameio.go
Normal file
@ -0,0 +1,191 @@
|
|||||||
|
package stario
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultFrameReaderBufferSize is the default transport read chunk size used by
|
||||||
|
// FrameReader.
|
||||||
|
const DefaultFrameReaderBufferSize = 32 * 1024
|
||||||
|
|
||||||
|
type frameReaderConnKey struct{}
|
||||||
|
|
||||||
|
// FrameWriter adapts StarQueue framing helpers to an io.Writer.
|
||||||
|
type FrameWriter struct {
|
||||||
|
writer io.Writer
|
||||||
|
queue *StarQueue
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFrameWriter creates a framing writer backed by queue. When queue is nil, a
|
||||||
|
// default StarQueue is created.
|
||||||
|
func NewFrameWriter(writer io.Writer, queue *StarQueue) *FrameWriter {
|
||||||
|
if queue == nil {
|
||||||
|
queue = NewQueue()
|
||||||
|
}
|
||||||
|
return &FrameWriter{
|
||||||
|
writer: writer,
|
||||||
|
queue: queue,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteFrame writes one framed payload.
|
||||||
|
func (writer *FrameWriter) WriteFrame(payload []byte) error {
|
||||||
|
if writer == nil || writer.writer == nil || writer.queue == nil {
|
||||||
|
return io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
return writer.queue.WriteFrame(writer.writer, payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteFrameBuffers writes one framed payload using net.Buffers when possible.
|
||||||
|
func (writer *FrameWriter) WriteFrameBuffers(payload []byte) error {
|
||||||
|
if writer == nil || writer.writer == nil || writer.queue == nil {
|
||||||
|
return io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
return writer.queue.WriteFrameBuffers(writer.writer, payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteFramesBuffers writes multiple framed payloads in one batch when
|
||||||
|
// possible.
|
||||||
|
func (writer *FrameWriter) WriteFramesBuffers(payloads ...[]byte) error {
|
||||||
|
if writer == nil || writer.writer == nil || writer.queue == nil {
|
||||||
|
return io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
return writer.queue.WriteFramesBuffers(writer.writer, payloads...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FrameReader adapts StarQueue parsing helpers to an io.Reader.
|
||||||
|
type FrameReader struct {
|
||||||
|
reader io.Reader
|
||||||
|
queue *StarQueue
|
||||||
|
connKey interface{}
|
||||||
|
readSize int
|
||||||
|
pending [][]byte
|
||||||
|
pendingErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFrameReader creates a framing reader backed by queue. When queue is nil, a
|
||||||
|
// default StarQueue is created.
|
||||||
|
func NewFrameReader(reader io.Reader, queue *StarQueue) *FrameReader {
|
||||||
|
if queue == nil {
|
||||||
|
queue = NewQueue()
|
||||||
|
}
|
||||||
|
return &FrameReader{
|
||||||
|
reader: reader,
|
||||||
|
queue: queue,
|
||||||
|
connKey: &frameReaderConnKey{},
|
||||||
|
readSize: DefaultFrameReaderBufferSize,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetReadBufferSize updates the underlying transport read chunk size.
|
||||||
|
func (reader *FrameReader) SetReadBufferSize(size int) {
|
||||||
|
if reader == nil || size <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
reader.readSize = size
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetConnKey overrides the internal queue connection key.
|
||||||
|
func (reader *FrameReader) SetConnKey(conn interface{}) error {
|
||||||
|
if reader == nil {
|
||||||
|
return io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
if conn == nil {
|
||||||
|
reader.connKey = &frameReaderConnKey{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := validateConnKey(conn); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
reader.connKey = conn
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next returns the next framed payload.
|
||||||
|
func (reader *FrameReader) Next() ([]byte, error) {
|
||||||
|
if reader == nil || reader.reader == nil || reader.queue == nil {
|
||||||
|
return nil, io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
if len(reader.pending) > 0 {
|
||||||
|
return reader.popPending(), nil
|
||||||
|
}
|
||||||
|
if reader.pendingErr != nil {
|
||||||
|
err := reader.pendingErr
|
||||||
|
reader.pendingErr = nil
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if reader.readSize <= 0 {
|
||||||
|
reader.readSize = DefaultFrameReaderBufferSize
|
||||||
|
}
|
||||||
|
buf := make([]byte, reader.readSize)
|
||||||
|
for {
|
||||||
|
n, readErr := reader.reader.Read(buf)
|
||||||
|
if n > 0 {
|
||||||
|
parseErr := reader.queue.ParseMessageOwned(buf[:n], reader.connKey, func(msg MsgQueue) error {
|
||||||
|
reader.pending = append(reader.pending, msg.Msg)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
err := reader.normalizeReadErr(parseErr, readErr)
|
||||||
|
if len(reader.pending) > 0 {
|
||||||
|
reader.pendingErr = err
|
||||||
|
return reader.popPending(), nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if readErr != nil {
|
||||||
|
return nil, reader.normalizeReadErr(nil, readErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (reader *FrameReader) popPending() []byte {
|
||||||
|
next := reader.pending[0]
|
||||||
|
reader.pending = reader.pending[1:]
|
||||||
|
return next
|
||||||
|
}
|
||||||
|
|
||||||
|
func joinFrameReaderError(parseErr error, readErr error) error {
|
||||||
|
switch {
|
||||||
|
case parseErr == nil:
|
||||||
|
return readErr
|
||||||
|
case readErr == nil:
|
||||||
|
return parseErr
|
||||||
|
default:
|
||||||
|
return errors.Join(parseErr, readErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (reader *FrameReader) normalizeReadErr(parseErr error, readErr error) error {
|
||||||
|
if errors.Is(readErr, io.EOF) && reader.hasBufferedState() {
|
||||||
|
reader.clearBufferedState()
|
||||||
|
readErr = io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
return joinFrameReaderError(parseErr, readErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (reader *FrameReader) hasBufferedState() bool {
|
||||||
|
if reader == nil || reader.queue == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
stateAny, ok := reader.queue.states.Load(reader.connKey)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
state, ok := stateAny.(*queConnState)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
state.mu.Lock()
|
||||||
|
defer state.mu.Unlock()
|
||||||
|
return len(state.buf) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (reader *FrameReader) clearBufferedState() {
|
||||||
|
if reader == nil || reader.queue == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
reader.queue.states.Delete(reader.connKey)
|
||||||
|
}
|
||||||
105
frameio_test.go
Normal file
105
frameio_test.go
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
package stario
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type chunkedReader struct {
|
||||||
|
data []byte
|
||||||
|
max int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (reader *chunkedReader) Read(p []byte) (int, error) {
|
||||||
|
if len(reader.data) == 0 {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
if reader.max > 0 && len(p) > reader.max {
|
||||||
|
p = p[:reader.max]
|
||||||
|
}
|
||||||
|
n := copy(p, reader.data)
|
||||||
|
reader.data = reader.data[n:]
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFrameWriterReaderRoundTrip(t *testing.T) {
|
||||||
|
var wire bytes.Buffer
|
||||||
|
writer := NewFrameWriter(&wire, nil)
|
||||||
|
if err := writer.WriteFrame([]byte("one")); err != nil {
|
||||||
|
t.Fatalf("WriteFrame failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := writer.WriteFramesBuffers([][]byte{[]byte("two"), []byte("three")}...); err != nil {
|
||||||
|
t.Fatalf("WriteFramesBuffers failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := NewFrameReader(bytes.NewReader(wire.Bytes()), nil)
|
||||||
|
cases := []string{"one", "two", "three"}
|
||||||
|
for _, want := range cases {
|
||||||
|
got, err := reader.Next()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Next returned error: %v", err)
|
||||||
|
}
|
||||||
|
if string(got) != want {
|
||||||
|
t.Fatalf("unexpected frame: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, err := reader.Next(); !errors.Is(err, io.EOF) {
|
||||||
|
t.Fatalf("expected EOF after draining frames, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFrameReaderHandlesPartialTransportReads(t *testing.T) {
|
||||||
|
var wire bytes.Buffer
|
||||||
|
writer := NewFrameWriter(&wire, nil)
|
||||||
|
if err := writer.WriteFrame([]byte("hello")); err != nil {
|
||||||
|
t.Fatalf("WriteFrame failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := NewFrameReader(&chunkedReader{data: wire.Bytes(), max: 3}, nil)
|
||||||
|
reader.SetReadBufferSize(3)
|
||||||
|
got, err := reader.Next()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Next returned error: %v", err)
|
||||||
|
}
|
||||||
|
if string(got) != "hello" {
|
||||||
|
t.Fatalf("unexpected frame: %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFrameWriterNilWriterFails(t *testing.T) {
|
||||||
|
writer := NewFrameWriter(nil, nil)
|
||||||
|
if err := writer.WriteFrame([]byte("hello")); !errors.Is(err, io.ErrClosedPipe) {
|
||||||
|
t.Fatalf("expected io.ErrClosedPipe, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFrameReaderTruncatedFrameReturnsUnexpectedEOF(t *testing.T) {
|
||||||
|
queue := NewQueue()
|
||||||
|
frame := queue.BuildMessage([]byte("hello"))
|
||||||
|
reader := NewFrameReader(bytes.NewReader(frame[:len(frame)-1]), nil)
|
||||||
|
|
||||||
|
if _, err := reader.Next(); !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||||
|
t.Fatalf("expected io.ErrUnexpectedEOF, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFrameReaderReturnsUnexpectedEOFAfterPendingFrames(t *testing.T) {
|
||||||
|
queue := NewQueue()
|
||||||
|
first := queue.BuildMessage([]byte("one"))
|
||||||
|
second := queue.BuildMessage([]byte("two"))
|
||||||
|
wire := append(append([]byte{}, first...), second[:len(second)-1]...)
|
||||||
|
reader := NewFrameReader(bytes.NewReader(wire), nil)
|
||||||
|
|
||||||
|
got, err := reader.Next()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error on first frame: %v", err)
|
||||||
|
}
|
||||||
|
if string(got) != "one" {
|
||||||
|
t.Fatalf("unexpected first frame: %q", got)
|
||||||
|
}
|
||||||
|
if _, err := reader.Next(); !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||||
|
t.Fatalf("expected io.ErrUnexpectedEOF after draining pending frame, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
6
go.mod
6
go.mod
@ -1,5 +1,7 @@
|
|||||||
module b612.me/stario
|
module b612.me/stario
|
||||||
|
|
||||||
go 1.16
|
go 1.20
|
||||||
|
|
||||||
require golang.org/x/crypto v0.26.0
|
require golang.org/x/term v0.23.0
|
||||||
|
|
||||||
|
require golang.org/x/sys v0.23.0 // indirect
|
||||||
|
|||||||
63
go.sum
63
go.sum
@ -1,67 +1,4 @@
|
|||||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
|
||||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
|
||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
|
||||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
|
||||||
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
|
|
||||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
|
||||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
|
||||||
golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw=
|
|
||||||
golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54=
|
|
||||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
|
||||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
|
||||||
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
|
||||||
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
|
||||||
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
|
||||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
|
||||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
|
||||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
|
||||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
|
||||||
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
|
|
||||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
|
||||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
|
||||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
|
||||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
|
||||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
|
||||||
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
|
||||||
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
|
||||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
|
||||||
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
|
||||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
|
||||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
|
||||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
|
||||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
|
||||||
golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM=
|
golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM=
|
||||||
golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
|
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
|
||||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
|
||||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
|
||||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
|
||||||
golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
|
|
||||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
|
||||||
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
|
||||||
golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU=
|
golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU=
|
||||||
golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk=
|
golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
|
||||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
|
||||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
|
||||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
|
||||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
|
||||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
|
||||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
|
||||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
|
||||||
golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
|
|
||||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
|
||||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
|
||||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
|
||||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
|
||||||
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
|
|
||||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
|
||||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
|
||||||
|
|||||||
235
io.go
235
io.go
@ -3,11 +3,12 @@ package stario
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"fmt"
|
"fmt"
|
||||||
"golang.org/x/crypto/ssh/terminal"
|
"golang.org/x/term"
|
||||||
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
type InputMsg struct {
|
type InputMsg struct {
|
||||||
@ -16,30 +17,54 @@ type InputMsg struct {
|
|||||||
skipSliceSigErr bool
|
skipSliceSigErr bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type rawInputSignalMode uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
rawInputSignalIgnore rawInputSignalMode = iota
|
||||||
|
rawInputSignalExit
|
||||||
|
rawInputSignalReturnError
|
||||||
|
)
|
||||||
|
|
||||||
type rawTerminalSession struct {
|
type rawTerminalSession struct {
|
||||||
fd int
|
fd int
|
||||||
state *terminal.State
|
state *term.State
|
||||||
reader *bufio.Reader
|
reader *bufio.Reader
|
||||||
|
input io.Closer
|
||||||
redrawHint string
|
redrawHint string
|
||||||
printNewline bool
|
printNewline bool
|
||||||
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var rawTerminalSessionFactory = newRawTerminalSession
|
||||||
|
var inputSignalHandler = signal
|
||||||
|
|
||||||
|
// Passwd reads one password-style line in raw mode.
|
||||||
|
//
|
||||||
|
// When the user presses an input signal such as Ctrl+C, this compatibility
|
||||||
|
// entry exits the current flow and returns an empty message with a nil error.
|
||||||
func Passwd(hint string, defaultVal string) InputMsg {
|
func Passwd(hint string, defaultVal string) InputMsg {
|
||||||
return passwd(hint, defaultVal, "", true)
|
return passwd(hint, defaultVal, "", rawInputSignalExit)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PasswdWithMask is like Passwd but echoes the provided mask string.
|
||||||
func PasswdWithMask(hint string, defaultVal string, mask string) InputMsg {
|
func PasswdWithMask(hint string, defaultVal string, mask string) InputMsg {
|
||||||
return passwd(hint, defaultVal, mask, true)
|
return passwd(hint, defaultVal, mask, rawInputSignalExit)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PasswdResponseSignal is like Passwd but preserves input-signal errors for
|
||||||
|
// callers that need to distinguish Ctrl+C / Ctrl+Z style exits.
|
||||||
func PasswdResponseSignal(hint string, defaultVal string) InputMsg {
|
func PasswdResponseSignal(hint string, defaultVal string) InputMsg {
|
||||||
return passwd(hint, defaultVal, "", true)
|
return passwd(hint, defaultVal, "", rawInputSignalReturnError)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PasswdResponseSignalWithMask is like PasswdResponseSignal but echoes the
|
||||||
|
// provided mask string.
|
||||||
func PasswdResponseSignalWithMask(hint string, defaultVal string, mask string) InputMsg {
|
func PasswdResponseSignalWithMask(hint string, defaultVal string, mask string) InputMsg {
|
||||||
return passwd(hint, defaultVal, mask, true)
|
return passwd(hint, defaultVal, mask, rawInputSignalReturnError)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MessageBoxRaw reads one line in raw mode without treating control keys as
|
||||||
|
// exit signals.
|
||||||
func MessageBoxRaw(hint string, defaultVal string) InputMsg {
|
func MessageBoxRaw(hint string, defaultVal string) InputMsg {
|
||||||
return messageBox(hint, defaultVal)
|
return messageBox(hint, defaultVal)
|
||||||
}
|
}
|
||||||
@ -48,35 +73,64 @@ func newRawTerminalSession(hint string, printNewline bool) (*rawTerminalSession,
|
|||||||
if hint != "" {
|
if hint != "" {
|
||||||
fmt.Print(hint)
|
fmt.Print(hint)
|
||||||
}
|
}
|
||||||
fd := int(os.Stdin.Fd())
|
input, err := openRawTerminalInput()
|
||||||
state, err := terminal.MakeRaw(fd)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
fd := int(input.Fd())
|
||||||
|
state, err := term.MakeRaw(fd)
|
||||||
|
if err != nil {
|
||||||
|
_ = input.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return &rawTerminalSession{
|
return &rawTerminalSession{
|
||||||
fd: fd,
|
fd: fd,
|
||||||
state: state,
|
state: state,
|
||||||
reader: bufio.NewReader(os.Stdin),
|
reader: bufio.NewReader(input),
|
||||||
|
input: input,
|
||||||
redrawHint: promptRedrawHint(hint),
|
redrawHint: promptRedrawHint(hint),
|
||||||
printNewline: printNewline,
|
printNewline: printNewline,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (session *rawTerminalSession) Close() {
|
func (session *rawTerminalSession) Close() {
|
||||||
if session == nil || session.state == nil {
|
if session == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
_ = terminal.Restore(session.fd, session.state)
|
session.mu.Lock()
|
||||||
|
defer session.mu.Unlock()
|
||||||
|
if session.state != nil {
|
||||||
|
_ = term.Restore(session.fd, session.state)
|
||||||
|
session.state = nil
|
||||||
|
}
|
||||||
if session.printNewline {
|
if session.printNewline {
|
||||||
fmt.Println()
|
fmt.Println()
|
||||||
|
session.printNewline = false
|
||||||
|
}
|
||||||
|
if session.input != nil {
|
||||||
|
_ = session.input.Close()
|
||||||
|
session.input = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (session *rawTerminalSession) Restore() error {
|
func (session *rawTerminalSession) Restore() error {
|
||||||
if session == nil || session.state == nil {
|
if session == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return terminal.Restore(session.fd, session.state)
|
session.mu.Lock()
|
||||||
|
defer session.mu.Unlock()
|
||||||
|
if session.state == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := term.Restore(session.fd, session.state); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
session.state = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (session *rawTerminalSession) Abort() {
|
||||||
|
session.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func promptRedrawHint(hint string) string {
|
func promptRedrawHint(hint string) string {
|
||||||
@ -101,6 +155,20 @@ func renderRawEcho(ioBuf []rune, mask string) string {
|
|||||||
return strings.Repeat(mask, len(ioBuf))
|
return strings.Repeat(mask, len(ioBuf))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func rawEchoRenderUnit(r rune, mask string, maskWidth int) (string, int, bool) {
|
||||||
|
if mask != "" {
|
||||||
|
if maskWidth <= 0 {
|
||||||
|
return "", 0, false
|
||||||
|
}
|
||||||
|
return mask, maskWidth, true
|
||||||
|
}
|
||||||
|
width := runeDisplayWidth(r)
|
||||||
|
if width <= 0 {
|
||||||
|
return "", 0, false
|
||||||
|
}
|
||||||
|
return string(r), width, true
|
||||||
|
}
|
||||||
|
|
||||||
func redrawPromptLine(hint string, echo string, lastWidth int) int {
|
func redrawPromptLine(hint string, echo string, lastWidth int) int {
|
||||||
nowWidth := stringDisplayWidth(hint) + stringDisplayWidth(echo)
|
nowWidth := stringDisplayWidth(hint) + stringDisplayWidth(echo)
|
||||||
clearWidth := lastWidth
|
clearWidth := lastWidth
|
||||||
@ -121,6 +189,22 @@ func redrawPromptLine(hint string, echo string, lastWidth int) int {
|
|||||||
return nowWidth
|
return nowWidth
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func erasePromptTail(width int) {
|
||||||
|
if width <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
backtrack := strings.Repeat("\b", width)
|
||||||
|
fmt.Print(backtrack)
|
||||||
|
fmt.Print(strings.Repeat(" ", width))
|
||||||
|
fmt.Print(backtrack)
|
||||||
|
}
|
||||||
|
|
||||||
|
func redrawPromptEcho(hint string, ioBuf []rune, mask string, lastWidth int) (int, int) {
|
||||||
|
echo := renderRawEcho(ioBuf, mask)
|
||||||
|
echoWidth := stringDisplayWidth(echo)
|
||||||
|
return echoWidth, redrawPromptLine(hint, echo, lastWidth)
|
||||||
|
}
|
||||||
|
|
||||||
func stringDisplayWidth(text string) int {
|
func stringDisplayWidth(text string) int {
|
||||||
width := 0
|
width := 0
|
||||||
for _, r := range text {
|
for _, r := range text {
|
||||||
@ -157,46 +241,79 @@ func isWideRune(r rune) bool {
|
|||||||
(r >= 0x20000 && r <= 0x3fffd))
|
(r >= 0x20000 && r <= 0x3fffd))
|
||||||
}
|
}
|
||||||
|
|
||||||
func rawLineInput(hint string, defaultVal string, mask string, handleSignal bool) InputMsg {
|
func signalInputResult(mode rawInputSignalMode, err error) InputMsg {
|
||||||
session, err := newRawTerminalSession(hint, true)
|
switch mode {
|
||||||
|
case rawInputSignalExit:
|
||||||
|
return InputMsg{msg: "", err: nil}
|
||||||
|
case rawInputSignalReturnError:
|
||||||
|
return InputMsg{msg: "", err: err}
|
||||||
|
default:
|
||||||
|
return InputMsg{msg: "", err: nil}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func rawLineInput(hint string, defaultVal string, mask string, signalMode rawInputSignalMode) InputMsg {
|
||||||
|
session, err := rawTerminalSessionFactory(hint, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return InputMsg{msg: "", err: err}
|
return InputMsg{msg: "", err: err}
|
||||||
}
|
}
|
||||||
defer session.Close()
|
defer session.Close()
|
||||||
|
return rawLineInputSession(session, defaultVal, mask, signalMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func rawLineInputSession(session *rawTerminalSession, defaultVal string, mask string, signalMode rawInputSignalMode) InputMsg {
|
||||||
|
if session == nil || session.reader == nil {
|
||||||
|
return InputMsg{msg: "", err: io.ErrClosedPipe}
|
||||||
|
}
|
||||||
ioBuf := make([]rune, 0, 16)
|
ioBuf := make([]rune, 0, 16)
|
||||||
lastWidth := 0
|
promptWidth := stringDisplayWidth(session.redrawHint)
|
||||||
|
maskWidth := stringDisplayWidth(mask)
|
||||||
|
echoWidth := 0
|
||||||
|
lastWidth := promptWidth
|
||||||
for {
|
for {
|
||||||
b, _, err := session.reader.ReadRune()
|
b, _, err := session.reader.ReadRune()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return InputMsg{msg: "", err: err}
|
return InputMsg{msg: "", err: err}
|
||||||
}
|
}
|
||||||
if handleSignal && isSignal(b) {
|
if signalMode != rawInputSignalIgnore && isSignal(b) {
|
||||||
if runtime.GOOS != "windows" {
|
session.Close()
|
||||||
if err := session.Restore(); err != nil {
|
if signalMode == rawInputSignalExit {
|
||||||
return InputMsg{msg: "", err: err}
|
return signalInputResult(signalMode, nil)
|
||||||
}
|
}
|
||||||
}
|
return signalInputResult(signalMode, inputSignalHandler(b))
|
||||||
if err := signal(b); err != nil {
|
|
||||||
return InputMsg{msg: "", err: err}
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
switch b {
|
switch b {
|
||||||
case 0x0d, 0x0a:
|
case 0x0d, 0x0a:
|
||||||
return InputMsg{msg: finalizeInputValue(string(ioBuf), defaultVal), err: nil}
|
return InputMsg{msg: finalizeInputValue(string(ioBuf), defaultVal), err: nil}
|
||||||
case 0x08, 0x7F:
|
case 0x08, 0x7F:
|
||||||
if len(ioBuf) > 0 {
|
if len(ioBuf) > 0 {
|
||||||
|
removed := ioBuf[len(ioBuf)-1]
|
||||||
ioBuf = ioBuf[:len(ioBuf)-1]
|
ioBuf = ioBuf[:len(ioBuf)-1]
|
||||||
|
if _, removedWidth, ok := rawEchoRenderUnit(removed, mask, maskWidth); ok {
|
||||||
|
erasePromptTail(removedWidth)
|
||||||
|
echoWidth -= removedWidth
|
||||||
|
if echoWidth < 0 {
|
||||||
|
echoWidth = 0
|
||||||
|
}
|
||||||
|
lastWidth = promptWidth + echoWidth
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
ioBuf = append(ioBuf, b)
|
ioBuf = append(ioBuf, b)
|
||||||
|
if appendText, appendWidth, ok := rawEchoRenderUnit(b, mask, maskWidth); ok {
|
||||||
|
fmt.Print(appendText)
|
||||||
|
echoWidth += appendWidth
|
||||||
|
lastWidth = promptWidth + echoWidth
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
lastWidth = redrawPromptLine(session.redrawHint, renderRawEcho(ioBuf, mask), lastWidth)
|
}
|
||||||
|
echoWidth, lastWidth = redrawPromptEcho(session.redrawHint, ioBuf, mask, lastWidth)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func messageBox(hint string, defaultVal string) InputMsg {
|
func messageBox(hint string, defaultVal string) InputMsg {
|
||||||
return rawLineInput(hint, defaultVal, "", false)
|
return rawLineInput(hint, defaultVal, "", rawInputSignalIgnore)
|
||||||
}
|
}
|
||||||
|
|
||||||
func isSignal(s rune) bool {
|
func isSignal(s rune) bool {
|
||||||
@ -208,10 +325,12 @@ func isSignal(s rune) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func passwd(hint string, defaultVal string, mask string, handleSignal bool) InputMsg {
|
func passwd(hint string, defaultVal string, mask string, signalMode rawInputSignalMode) InputMsg {
|
||||||
return rawLineInput(hint, defaultVal, mask, handleSignal)
|
return rawLineInput(hint, defaultVal, mask, signalMode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MessageBox reads one line in cooked mode and falls back to defaultVal when
|
||||||
|
// the trimmed input is empty.
|
||||||
func MessageBox(hint string, defaultVal string) InputMsg {
|
func MessageBox(hint string, defaultVal string) InputMsg {
|
||||||
if hint != "" {
|
if hint != "" {
|
||||||
fmt.Print(hint)
|
fmt.Print(hint)
|
||||||
@ -264,7 +383,7 @@ func (im InputMsg) sliceFn(sep string, fn func(string) (interface{}, error)) ([]
|
|||||||
return res, err
|
return res, err
|
||||||
}
|
}
|
||||||
for _, v := range data {
|
for _, v := range data {
|
||||||
code, err := fn(v)
|
code, err := fn(strings.TrimSpace(v))
|
||||||
if err != nil && !im.skipSliceSigErr {
|
if err != nil && !im.skipSliceSigErr {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else if err == nil {
|
} else if err == nil {
|
||||||
@ -428,7 +547,8 @@ func (im InputMsg) MustFloat32() float32 {
|
|||||||
|
|
||||||
func (im InputMsg) SliceFloat32(sep string) ([]float32, error) {
|
func (im InputMsg) SliceFloat32(sep string) ([]float32, error) {
|
||||||
data, err := im.sliceFn(sep, func(v string) (interface{}, error) {
|
data, err := im.sliceFn(sep, func(v string) (interface{}, error) {
|
||||||
return strconv.ParseFloat(v, 32)
|
f, err := strconv.ParseFloat(v, 32)
|
||||||
|
return float32(f), err
|
||||||
})
|
})
|
||||||
var res []float32
|
var res []float32
|
||||||
for _, v := range data {
|
for _, v := range data {
|
||||||
@ -477,13 +597,47 @@ func YesNoE(hint string, defaults bool) (bool, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildTriggerPrefixTable(triggerRunes []rune) []int {
|
||||||
|
if len(triggerRunes) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
prefix := make([]int, len(triggerRunes))
|
||||||
|
for i := 1; i < len(triggerRunes); i++ {
|
||||||
|
j := prefix[i-1]
|
||||||
|
for j > 0 && triggerRunes[i] != triggerRunes[j] {
|
||||||
|
j = prefix[j-1]
|
||||||
|
}
|
||||||
|
if triggerRunes[i] == triggerRunes[j] {
|
||||||
|
j++
|
||||||
|
}
|
||||||
|
prefix[i] = j
|
||||||
|
}
|
||||||
|
return prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
func advanceTriggerIndex(triggerRunes []rune, prefix []int, current int, input rune) (int, bool) {
|
||||||
|
if len(triggerRunes) == 0 {
|
||||||
|
return 0, true
|
||||||
|
}
|
||||||
|
for current > 0 && input != triggerRunes[current] {
|
||||||
|
current = prefix[current-1]
|
||||||
|
}
|
||||||
|
if input == triggerRunes[current] {
|
||||||
|
current++
|
||||||
|
if current == len(triggerRunes) {
|
||||||
|
return current, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return current, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopUntil keeps reading raw input until trigger is matched.
|
||||||
|
// When trigger == "", it returns after the first key press, which is used for
|
||||||
|
// "press any key to continue" style prompts.
|
||||||
func StopUntil(hint string, trigger string, repeat bool) error {
|
func StopUntil(hint string, trigger string, repeat bool) error {
|
||||||
triggerRunes := []rune(trigger)
|
triggerRunes := []rune(trigger)
|
||||||
pressLen := len(triggerRunes)
|
prefix := buildTriggerPrefixTable(triggerRunes)
|
||||||
if trigger == "" {
|
session, err := rawTerminalSessionFactory(hint, false)
|
||||||
pressLen = 1
|
|
||||||
}
|
|
||||||
session, err := newRawTerminalSession(hint, false)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -497,11 +651,12 @@ func StopUntil(hint string, trigger string, repeat bool) error {
|
|||||||
if trigger == "" {
|
if trigger == "" {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if b == triggerRunes[i] {
|
next, complete := advanceTriggerIndex(triggerRunes, prefix, i, b)
|
||||||
i++
|
if complete {
|
||||||
if i == pressLen {
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
if next > 0 {
|
||||||
|
i = next
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
i = 0
|
i = 0
|
||||||
|
|||||||
81
io_benchmark_test.go
Normal file
81
io_benchmark_test.go
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
package stario
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func benchmarkRawRedrawAppend(b *testing.B, current []rune, next rune, mask string) {
|
||||||
|
b.Run("legacy", func(b *testing.B) {
|
||||||
|
ioBuf := make([]rune, len(current)+1)
|
||||||
|
copy(ioBuf, current)
|
||||||
|
ioBuf[len(current)] = next
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
echo := renderRawEcho(ioBuf, mask)
|
||||||
|
_ = stringDisplayWidth(echo)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("fast", func(b *testing.B) {
|
||||||
|
maskWidth := stringDisplayWidth(mask)
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _, ok := rawEchoRenderUnit(next, mask, maskWidth)
|
||||||
|
if !ok {
|
||||||
|
b.Fatal("fast path rejected append rune")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchmarkRawRedrawBackspace(b *testing.B, current []rune, mask string) {
|
||||||
|
if len(current) == 0 {
|
||||||
|
b.Fatal("backspace benchmark requires non-empty input")
|
||||||
|
}
|
||||||
|
last := current[len(current)-1]
|
||||||
|
trimmed := current[:len(current)-1]
|
||||||
|
|
||||||
|
b.Run("legacy", func(b *testing.B) {
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
echo := renderRawEcho(trimmed, mask)
|
||||||
|
_ = stringDisplayWidth(echo)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("fast", func(b *testing.B) {
|
||||||
|
maskWidth := stringDisplayWidth(mask)
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _, ok := rawEchoRenderUnit(last, mask, maskWidth)
|
||||||
|
if !ok {
|
||||||
|
b.Fatal("fast path rejected backspace rune")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRawRedrawAppendPlainLong(b *testing.B) {
|
||||||
|
current := []rune(strings.Repeat("a", 4096))
|
||||||
|
benchmarkRawRedrawAppend(b, current, 'b', "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRawRedrawAppendMaskLong(b *testing.B) {
|
||||||
|
current := []rune(strings.Repeat("a", 4096))
|
||||||
|
benchmarkRawRedrawAppend(b, current, 'b', "[]")
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRawRedrawBackspacePlainLong(b *testing.B) {
|
||||||
|
current := []rune(strings.Repeat("a", 4096))
|
||||||
|
benchmarkRawRedrawBackspace(b, current, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRawRedrawBackspaceMaskLong(b *testing.B) {
|
||||||
|
current := []rune(strings.Repeat("a", 4096))
|
||||||
|
benchmarkRawRedrawBackspace(b, current, "[]")
|
||||||
|
}
|
||||||
177
io_context.go
Normal file
177
io_context.go
Normal file
@ -0,0 +1,177 @@
|
|||||||
|
package stario
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultCopyContextBufferSize = 32 * 1024
|
||||||
|
|
||||||
|
type readContextResult struct {
|
||||||
|
data []byte
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type writeContextResult struct {
|
||||||
|
n int
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadFullContext reads exactly len(buf) bytes unless the context is canceled
|
||||||
|
// or the underlying reader returns an error.
|
||||||
|
//
|
||||||
|
// If ctx is canceled while the underlying Read call is blocked, ReadFullContext
|
||||||
|
// returns ctx.Err() without waiting for that call to finish. The underlying
|
||||||
|
// reader may still complete asynchronously afterwards.
|
||||||
|
func ReadFullContext(ctx context.Context, reader io.Reader, buf []byte) (int, error) {
|
||||||
|
if reader == nil {
|
||||||
|
return 0, io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
total := 0
|
||||||
|
for total < len(buf) {
|
||||||
|
chunk, err := readContext(ctx, reader, len(buf)-total)
|
||||||
|
if len(chunk) > 0 {
|
||||||
|
total += copy(buf[total:], chunk)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
if total > 0 {
|
||||||
|
return total, io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
return total, io.EOF
|
||||||
|
}
|
||||||
|
return total, err
|
||||||
|
}
|
||||||
|
if len(chunk) == 0 {
|
||||||
|
return total, io.ErrNoProgress
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteFullContext writes the full payload unless the context is canceled or
|
||||||
|
// the underlying writer returns an error.
|
||||||
|
//
|
||||||
|
// If ctx is canceled while the underlying Write call is blocked,
|
||||||
|
// WriteFullContext returns ctx.Err() without waiting for that call to finish.
|
||||||
|
// The underlying writer may still complete asynchronously afterwards.
|
||||||
|
func WriteFullContext(ctx context.Context, writer io.Writer, data []byte) (int, error) {
|
||||||
|
if writer == nil {
|
||||||
|
return 0, io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
total := 0
|
||||||
|
for total < len(data) {
|
||||||
|
written, err := writeContext(ctx, writer, data[total:])
|
||||||
|
if written > 0 {
|
||||||
|
total += written
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return total, err
|
||||||
|
}
|
||||||
|
if written == 0 {
|
||||||
|
return total, io.ErrNoProgress
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CopyContext copies from src to dst until EOF, context cancellation, or a
|
||||||
|
// non-EOF error occurs.
|
||||||
|
//
|
||||||
|
// If ctx is canceled while the current read or write is blocked, CopyContext
|
||||||
|
// returns ctx.Err() without waiting for that operation to finish.
|
||||||
|
func CopyContext(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) {
|
||||||
|
if dst == nil || src == nil {
|
||||||
|
return 0, io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
var copied int64
|
||||||
|
for {
|
||||||
|
chunk, readErr := readContext(ctx, src, defaultCopyContextBufferSize)
|
||||||
|
if len(chunk) > 0 {
|
||||||
|
written, writeErr := WriteFullContext(ctx, dst, chunk)
|
||||||
|
copied += int64(written)
|
||||||
|
if writeErr != nil {
|
||||||
|
return copied, writeErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if readErr != nil {
|
||||||
|
if errors.Is(readErr, io.EOF) {
|
||||||
|
return copied, nil
|
||||||
|
}
|
||||||
|
return copied, readErr
|
||||||
|
}
|
||||||
|
if len(chunk) == 0 {
|
||||||
|
return copied, io.ErrNoProgress
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func readContext(ctx context.Context, reader io.Reader, size int) ([]byte, error) {
|
||||||
|
if size <= 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
resultCh := make(chan readContextResult, 1)
|
||||||
|
go func() {
|
||||||
|
tmp := make([]byte, size)
|
||||||
|
n, err := reader.Read(tmp)
|
||||||
|
if n > 0 {
|
||||||
|
tmp = tmp[:n]
|
||||||
|
} else {
|
||||||
|
tmp = nil
|
||||||
|
}
|
||||||
|
resultCh <- readContextResult{data: tmp, err: err}
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case result := <-resultCh:
|
||||||
|
return result.data, result.err
|
||||||
|
case <-ctx.Done():
|
||||||
|
select {
|
||||||
|
case result := <-resultCh:
|
||||||
|
return result.data, result.err
|
||||||
|
default:
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeContext(ctx context.Context, writer io.Writer, data []byte) (int, error) {
|
||||||
|
if len(data) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
payload := append([]byte(nil), data...)
|
||||||
|
resultCh := make(chan writeContextResult, 1)
|
||||||
|
go func() {
|
||||||
|
n, err := writer.Write(payload)
|
||||||
|
resultCh <- writeContextResult{n: n, err: err}
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case result := <-resultCh:
|
||||||
|
return result.n, result.err
|
||||||
|
case <-ctx.Done():
|
||||||
|
select {
|
||||||
|
case result := <-resultCh:
|
||||||
|
return result.n, result.err
|
||||||
|
default:
|
||||||
|
return 0, ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
153
io_context_test.go
Normal file
153
io_context_test.go
Normal file
@ -0,0 +1,153 @@
|
|||||||
|
package stario
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type chunkedWriter struct {
|
||||||
|
buf bytes.Buffer
|
||||||
|
max int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (writer *chunkedWriter) Write(p []byte) (int, error) {
|
||||||
|
if writer.max <= 0 || len(p) <= writer.max {
|
||||||
|
return writer.buf.Write(p)
|
||||||
|
}
|
||||||
|
return writer.buf.Write(p[:writer.max])
|
||||||
|
}
|
||||||
|
|
||||||
|
type blockingReader struct {
|
||||||
|
started chan struct{}
|
||||||
|
release chan struct{}
|
||||||
|
data []byte
|
||||||
|
once sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
func (reader *blockingReader) Read(p []byte) (int, error) {
|
||||||
|
reader.once.Do(func() { close(reader.started) })
|
||||||
|
<-reader.release
|
||||||
|
n := copy(p, reader.data)
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type blockingWriter struct {
|
||||||
|
started chan struct{}
|
||||||
|
release chan struct{}
|
||||||
|
once sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
func (writer *blockingWriter) Write(p []byte) (int, error) {
|
||||||
|
writer.once.Do(func() { close(writer.started) })
|
||||||
|
<-writer.release
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadFullContext(t *testing.T) {
|
||||||
|
buf := make([]byte, 5)
|
||||||
|
n, err := ReadFullContext(context.Background(), bytes.NewBufferString("hello"), buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadFullContext returned error: %v", err)
|
||||||
|
}
|
||||||
|
if n != 5 || string(buf) != "hello" {
|
||||||
|
t.Fatalf("unexpected payload: n=%d data=%q", n, buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadFullContextReturnsUnexpectedEOF(t *testing.T) {
|
||||||
|
buf := make([]byte, 5)
|
||||||
|
n, err := ReadFullContext(context.Background(), bytes.NewBufferString("hey"), buf)
|
||||||
|
if !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||||
|
t.Fatalf("expected unexpected EOF, got %v", err)
|
||||||
|
}
|
||||||
|
if n != 3 || string(buf[:n]) != "hey" {
|
||||||
|
t.Fatalf("unexpected payload: n=%d data=%q", n, buf[:n])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadFullContextCanceledWhileBlocked(t *testing.T) {
|
||||||
|
reader := &blockingReader{
|
||||||
|
started: make(chan struct{}),
|
||||||
|
release: make(chan struct{}),
|
||||||
|
data: []byte("hello"),
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
buf := make([]byte, 5)
|
||||||
|
_, err := ReadFullContext(ctx, reader, buf)
|
||||||
|
done <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-reader.started
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-done:
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatalf("expected context canceled, got %v", err)
|
||||||
|
}
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Fatal("ReadFullContext did not return after cancel")
|
||||||
|
}
|
||||||
|
|
||||||
|
close(reader.release)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteFullContext(t *testing.T) {
|
||||||
|
writer := &chunkedWriter{max: 2}
|
||||||
|
n, err := WriteFullContext(context.Background(), writer, []byte("hello"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteFullContext returned error: %v", err)
|
||||||
|
}
|
||||||
|
if n != 5 || writer.buf.String() != "hello" {
|
||||||
|
t.Fatalf("unexpected write result: n=%d data=%q", n, writer.buf.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteFullContextCanceledWhileBlocked(t *testing.T) {
|
||||||
|
writer := &blockingWriter{
|
||||||
|
started: make(chan struct{}),
|
||||||
|
release: make(chan struct{}),
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
_, err := WriteFullContext(ctx, writer, []byte("hello"))
|
||||||
|
done <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-writer.started
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-done:
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatalf("expected context canceled, got %v", err)
|
||||||
|
}
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Fatal("WriteFullContext did not return after cancel")
|
||||||
|
}
|
||||||
|
|
||||||
|
close(writer.release)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopyContext(t *testing.T) {
|
||||||
|
var dst bytes.Buffer
|
||||||
|
written, err := CopyContext(context.Background(), &dst, bytes.NewBufferString("hello world"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CopyContext returned error: %v", err)
|
||||||
|
}
|
||||||
|
if written != int64(len("hello world")) || dst.String() != "hello world" {
|
||||||
|
t.Fatalf("unexpected copy result: written=%d data=%q", written, dst.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
150
io_test.go
150
io_test.go
@ -1,10 +1,32 @@
|
|||||||
package stario
|
package stario
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func installRawInputStub(t *testing.T, input string, signalErr error) {
|
||||||
|
t.Helper()
|
||||||
|
prevFactory := rawTerminalSessionFactory
|
||||||
|
prevSignalHandler := inputSignalHandler
|
||||||
|
rawTerminalSessionFactory = func(hint string, printNewline bool) (*rawTerminalSession, error) {
|
||||||
|
return &rawTerminalSession{
|
||||||
|
reader: bufio.NewReader(strings.NewReader(input)),
|
||||||
|
redrawHint: promptRedrawHint(hint),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
inputSignalHandler = func(sigtype rune) error {
|
||||||
|
return signalErr
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
rawTerminalSessionFactory = prevFactory
|
||||||
|
inputSignalHandler = prevSignalHandler
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestPromptRedrawHint(t *testing.T) {
|
func TestPromptRedrawHint(t *testing.T) {
|
||||||
got := promptRedrawHint("头部提示\n 中文确认: ")
|
got := promptRedrawHint("头部提示\n 中文确认: ")
|
||||||
if got != "中文确认:" {
|
if got != "中文确认:" {
|
||||||
@ -19,6 +41,81 @@ func TestStringDisplayWidth(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRawEchoRenderUnitPlainRune(t *testing.T) {
|
||||||
|
text, width, ok := rawEchoRenderUnit('中', "", 0)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected plain wide rune to use fast path")
|
||||||
|
}
|
||||||
|
if text != "中" || width != 2 {
|
||||||
|
t.Fatalf("unexpected render unit: got (%q, %d) want (%q, %d)", text, width, "中", 2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRawEchoRenderUnitMaskedRune(t *testing.T) {
|
||||||
|
text, width, ok := rawEchoRenderUnit('a', "[]", stringDisplayWidth("[]"))
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected masked rune to use fast path")
|
||||||
|
}
|
||||||
|
if text != "[]" || width != 2 {
|
||||||
|
t.Fatalf("unexpected render unit: got (%q, %d) want (%q, %d)", text, width, "[]", 2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRawEchoRenderUnitFallsBackForControlRune(t *testing.T) {
|
||||||
|
if _, _, ok := rawEchoRenderUnit('\x00', "", 0); ok {
|
||||||
|
t.Fatal("expected control rune to fall back to full redraw")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSignalInputResultExitSuppressesError(t *testing.T) {
|
||||||
|
got := signalInputResult(rawInputSignalExit, ErrSignalInterrupt)
|
||||||
|
if got.err != nil {
|
||||||
|
t.Fatalf("expected nil error, got %v", got.err)
|
||||||
|
}
|
||||||
|
if got.msg != "" {
|
||||||
|
t.Fatalf("expected empty message, got %q", got.msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSignalInputResultReturnErrorPreservesSignal(t *testing.T) {
|
||||||
|
got := signalInputResult(rawInputSignalReturnError, ErrSignalInterrupt)
|
||||||
|
if !errors.Is(got.err, ErrSignalInterrupt) {
|
||||||
|
t.Fatalf("expected signal error, got %v", got.err)
|
||||||
|
}
|
||||||
|
if got.msg != "" {
|
||||||
|
t.Fatalf("expected empty message, got %q", got.msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPasswdSuppressesSignalError(t *testing.T) {
|
||||||
|
installRawInputStub(t, string([]rune{0x03}), ErrSignalInterrupt)
|
||||||
|
got := Passwd("", "fallback")
|
||||||
|
if got.err != nil {
|
||||||
|
t.Fatalf("expected nil error, got %v", got.err)
|
||||||
|
}
|
||||||
|
if got.msg != "" {
|
||||||
|
t.Fatalf("expected empty message after signal exit, got %q", got.msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPasswdResponseSignalPreservesSignalError(t *testing.T) {
|
||||||
|
installRawInputStub(t, string([]rune{0x03}), ErrSignalInterrupt)
|
||||||
|
got := PasswdResponseSignal("", "fallback")
|
||||||
|
if !errors.Is(got.err, ErrSignalInterrupt) {
|
||||||
|
t.Fatalf("expected interrupt error, got %v", got.err)
|
||||||
|
}
|
||||||
|
if got.msg != "" {
|
||||||
|
t.Fatalf("expected empty message after signal exit, got %q", got.msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStopUntilEmptyTriggerReturnsAfterFirstKey(t *testing.T) {
|
||||||
|
installRawInputStub(t, "abc", nil)
|
||||||
|
if err := StopUntil("", "", false); err != nil {
|
||||||
|
t.Fatalf("StopUntil returned error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestParseYesNoValue(t *testing.T) {
|
func TestParseYesNoValue(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
name string
|
name string
|
||||||
@ -40,6 +137,59 @@ func TestParseYesNoValue(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSliceFloat32(t *testing.T) {
|
||||||
|
data := InputMsg{msg: "1.5,2.25", err: nil}
|
||||||
|
got, err := data.SliceFloat32(",")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SliceFloat32 returned error: %v", err)
|
||||||
|
}
|
||||||
|
if len(got) != 2 || got[0] != float32(1.5) || got[1] != float32(2.25) {
|
||||||
|
t.Fatalf("unexpected float32 slice: %#v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTypedSliceParsingTrimsTokenWhitespace(t *testing.T) {
|
||||||
|
ints, err := (InputMsg{msg: "1, 2, 3"}).SliceInt(",")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SliceInt returned error: %v", err)
|
||||||
|
}
|
||||||
|
if len(ints) != 3 || ints[0] != 1 || ints[1] != 2 || ints[2] != 3 {
|
||||||
|
t.Fatalf("unexpected int slice: %#v", ints)
|
||||||
|
}
|
||||||
|
|
||||||
|
bools, err := (InputMsg{msg: "true, false, true"}).SliceBool(",")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SliceBool returned error: %v", err)
|
||||||
|
}
|
||||||
|
if len(bools) != 3 || !bools[0] || bools[1] || !bools[2] {
|
||||||
|
t.Fatalf("unexpected bool slice: %#v", bools)
|
||||||
|
}
|
||||||
|
|
||||||
|
float64s, err := (InputMsg{msg: "1.25, 2.5, 3.75"}).SliceFloat64(",")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SliceFloat64 returned error: %v", err)
|
||||||
|
}
|
||||||
|
if len(float64s) != 3 || float64s[0] != 1.25 || float64s[1] != 2.5 || float64s[2] != 3.75 {
|
||||||
|
t.Fatalf("unexpected float64 slice: %#v", float64s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdvanceTriggerIndexHandlesOverlap(t *testing.T) {
|
||||||
|
trigger := []rune("aba")
|
||||||
|
prefix := buildTriggerPrefixTable(trigger)
|
||||||
|
index := 0
|
||||||
|
complete := false
|
||||||
|
for _, r := range []rune("aaba") {
|
||||||
|
index, complete = advanceTriggerIndex(trigger, prefix, index, r)
|
||||||
|
if complete {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !complete {
|
||||||
|
t.Fatal("expected overlapped trigger to complete")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func Test_Slice(t *testing.T) {
|
func Test_Slice(t *testing.T) {
|
||||||
var data = InputMsg{
|
var data = InputMsg{
|
||||||
msg: "true,false,true,true,false,0,1,hello",
|
msg: "true,false,true,true,false,0,1,hello",
|
||||||
|
|||||||
71
pipe.go
Normal file
71
pipe.go
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
package stario
|
||||||
|
|
||||||
|
import "io"
|
||||||
|
|
||||||
|
// StarPipeReader is the read side returned by NewStarPipe.
|
||||||
|
type StarPipeReader struct {
|
||||||
|
buf *StarBuffer
|
||||||
|
}
|
||||||
|
|
||||||
|
// StarPipeWriter is the write side returned by NewStarPipe.
|
||||||
|
type StarPipeWriter struct {
|
||||||
|
buf *StarBuffer
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewStarPipe creates a buffered in-memory pipe backed by StarBuffer.
|
||||||
|
//
|
||||||
|
// The writer side uses Close to signal a graceful end-of-stream and Abort to
|
||||||
|
// fail both sides immediately.
|
||||||
|
func NewStarPipe(capacity uint64) (*StarPipeReader, *StarPipeWriter, error) {
|
||||||
|
buf, err := NewStarBuffer(capacity)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return &StarPipeReader{buf: buf}, &StarPipeWriter{buf: buf}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read reads buffered bytes from the pipe.
|
||||||
|
func (reader *StarPipeReader) Read(p []byte) (int, error) {
|
||||||
|
if reader == nil || reader.buf == nil {
|
||||||
|
return 0, io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
return reader.buf.Read(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close aborts the pipe from the read side and wakes blocked writers.
|
||||||
|
func (reader *StarPipeReader) Close() error {
|
||||||
|
if reader == nil || reader.buf == nil {
|
||||||
|
return io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
return reader.buf.Abort()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Abort aborts the pipe from the read side and wakes blocked writers.
|
||||||
|
func (reader *StarPipeReader) Abort() error {
|
||||||
|
return reader.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write writes bytes into the pipe buffer.
|
||||||
|
func (writer *StarPipeWriter) Write(p []byte) (int, error) {
|
||||||
|
if writer == nil || writer.buf == nil {
|
||||||
|
return 0, io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
return writer.buf.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close gracefully closes the write side. Buffered bytes remain readable until
|
||||||
|
// drained.
|
||||||
|
func (writer *StarPipeWriter) Close() error {
|
||||||
|
if writer == nil || writer.buf == nil {
|
||||||
|
return io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
return writer.buf.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Abort aborts the pipe immediately.
|
||||||
|
func (writer *StarPipeWriter) Abort() error {
|
||||||
|
if writer == nil || writer.buf == nil {
|
||||||
|
return io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
return writer.buf.Abort()
|
||||||
|
}
|
||||||
53
pipe_test.go
Normal file
53
pipe_test.go
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
package stario
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStarPipeRoundTrip(t *testing.T) {
|
||||||
|
reader, writer, err := NewStarPipe(16)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
want := []byte("hello world")
|
||||||
|
go func() {
|
||||||
|
_, _ = writer.Write(want)
|
||||||
|
_ = writer.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
got, err := io.ReadAll(reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadAll failed: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected payload: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarPipeReaderCloseAbortsWriter(t *testing.T) {
|
||||||
|
reader, writer, err := NewStarPipe(4)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := reader.Close(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := writer.Write([]byte("x")); !errors.Is(err, ErrStarBufferClosed) {
|
||||||
|
t.Fatalf("unexpected writer error after reader close: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarPipeNilEndsReportClosedPipe(t *testing.T) {
|
||||||
|
var reader *StarPipeReader
|
||||||
|
var writer *StarPipeWriter
|
||||||
|
|
||||||
|
if _, err := reader.Read(make([]byte, 1)); !errors.Is(err, io.ErrClosedPipe) {
|
||||||
|
t.Fatalf("unexpected nil reader error: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := writer.Write([]byte("x")); !errors.Is(err, io.ErrClosedPipe) {
|
||||||
|
t.Fatalf("unexpected nil writer error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
325
que.go
325
que.go
@ -1,325 +0,0 @@
|
|||||||
package stario
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
var ErrDeadlineExceeded error = errors.New("deadline exceeded")
|
|
||||||
|
|
||||||
// 识别头
|
|
||||||
var header = []byte{11, 27, 19, 96, 12, 25, 02, 20}
|
|
||||||
|
|
||||||
// MsgQueue 为基本的信息单位
|
|
||||||
type MsgQueue struct {
|
|
||||||
ID uint16
|
|
||||||
Msg []byte
|
|
||||||
Conn interface{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// StarQueue 为流数据中的消息队列分发
|
|
||||||
type StarQueue struct {
|
|
||||||
maxLength uint32
|
|
||||||
count int64
|
|
||||||
Encode bool
|
|
||||||
msgID uint16
|
|
||||||
msgPool chan MsgQueue
|
|
||||||
unFinMsg sync.Map
|
|
||||||
lastID int //= -1
|
|
||||||
ctx context.Context
|
|
||||||
cancel context.CancelFunc
|
|
||||||
duration time.Duration
|
|
||||||
EncodeFunc func([]byte) []byte
|
|
||||||
DecodeFunc func([]byte) []byte
|
|
||||||
//restoreMu sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewQueueCtx(ctx context.Context, count int64, maxMsgLength uint32) *StarQueue {
|
|
||||||
var q StarQueue
|
|
||||||
q.Encode = false
|
|
||||||
q.count = count
|
|
||||||
q.maxLength = maxMsgLength
|
|
||||||
q.msgPool = make(chan MsgQueue, count)
|
|
||||||
if ctx == nil {
|
|
||||||
q.ctx, q.cancel = context.WithCancel(context.Background())
|
|
||||||
} else {
|
|
||||||
q.ctx, q.cancel = context.WithCancel(ctx)
|
|
||||||
}
|
|
||||||
q.duration = 0
|
|
||||||
return &q
|
|
||||||
}
|
|
||||||
func NewQueueWithCount(count int64) *StarQueue {
|
|
||||||
return NewQueueCtx(nil, count, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewQueue 建立一个新消息队列
|
|
||||||
func NewQueue() *StarQueue {
|
|
||||||
return NewQueueWithCount(32)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Uint32ToByte 4位uint32转byte
|
|
||||||
func Uint32ToByte(src uint32) []byte {
|
|
||||||
res := make([]byte, 4)
|
|
||||||
res[3] = uint8(src)
|
|
||||||
res[2] = uint8(src >> 8)
|
|
||||||
res[1] = uint8(src >> 16)
|
|
||||||
res[0] = uint8(src >> 24)
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
// ByteToUint32 byte转4位uint32
|
|
||||||
func ByteToUint32(src []byte) uint32 {
|
|
||||||
var res uint32
|
|
||||||
buffer := bytes.NewBuffer(src)
|
|
||||||
binary.Read(buffer, binary.BigEndian, &res)
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
// Uint16ToByte 2位uint16转byte
|
|
||||||
func Uint16ToByte(src uint16) []byte {
|
|
||||||
res := make([]byte, 2)
|
|
||||||
res[1] = uint8(src)
|
|
||||||
res[0] = uint8(src >> 8)
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
// ByteToUint16 用于byte转uint16
|
|
||||||
func ByteToUint16(src []byte) uint16 {
|
|
||||||
var res uint16
|
|
||||||
buffer := bytes.NewBuffer(src)
|
|
||||||
binary.Read(buffer, binary.BigEndian, &res)
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
// BuildMessage 生成编码后的信息用于发送
|
|
||||||
func (q *StarQueue) BuildMessage(src []byte) []byte {
|
|
||||||
var buff bytes.Buffer
|
|
||||||
q.msgID++
|
|
||||||
if q.Encode {
|
|
||||||
src = q.EncodeFunc(src)
|
|
||||||
}
|
|
||||||
length := uint32(len(src))
|
|
||||||
buff.Write(header)
|
|
||||||
buff.Write(Uint32ToByte(length))
|
|
||||||
buff.Write(Uint16ToByte(q.msgID))
|
|
||||||
buff.Write(src)
|
|
||||||
return buff.Bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
// BuildHeader 生成编码后的Header用于发送
|
|
||||||
func (q *StarQueue) BuildHeader(length uint32) []byte {
|
|
||||||
var buff bytes.Buffer
|
|
||||||
q.msgID++
|
|
||||||
buff.Write(header)
|
|
||||||
buff.Write(Uint32ToByte(length))
|
|
||||||
buff.Write(Uint16ToByte(q.msgID))
|
|
||||||
return buff.Bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
type unFinMsg struct {
|
|
||||||
ID uint16
|
|
||||||
LengthRecv uint32
|
|
||||||
// HeaderMsg 信息头,应当为14位:8位识别码+4位长度码+2位id
|
|
||||||
HeaderMsg []byte
|
|
||||||
RecvMsg []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *StarQueue) push2list(msg MsgQueue) {
|
|
||||||
q.msgPool <- msg
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseMessage 用于解析收到的msg信息
|
|
||||||
func (q *StarQueue) ParseMessage(msg []byte, conn interface{}) error {
|
|
||||||
return q.parseMessage(msg, conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseMessage 用于解析收到的msg信息
|
|
||||||
func (q *StarQueue) parseMessage(msg []byte, conn interface{}) error {
|
|
||||||
tmp, ok := q.unFinMsg.Load(conn)
|
|
||||||
if ok { //存在未完成的信息
|
|
||||||
lastMsg := tmp.(*unFinMsg)
|
|
||||||
headerLen := len(lastMsg.HeaderMsg)
|
|
||||||
if headerLen < 14 { //未完成头标题
|
|
||||||
//传输的数据不能填充header头
|
|
||||||
if len(msg) < 14-headerLen {
|
|
||||||
//加入header头并退出
|
|
||||||
lastMsg.HeaderMsg = bytesMerge(lastMsg.HeaderMsg, msg)
|
|
||||||
q.unFinMsg.Store(conn, lastMsg)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
//获取14字节完整的header
|
|
||||||
header := msg[0 : 14-headerLen]
|
|
||||||
lastMsg.HeaderMsg = bytesMerge(lastMsg.HeaderMsg, header)
|
|
||||||
//检查收到的header是否为认证header
|
|
||||||
//若不是,丢弃并重新来过
|
|
||||||
if !checkHeader(lastMsg.HeaderMsg[0:8]) {
|
|
||||||
q.unFinMsg.Delete(conn)
|
|
||||||
if len(msg) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return q.parseMessage(msg, conn)
|
|
||||||
}
|
|
||||||
//获得本数据包长度
|
|
||||||
lastMsg.LengthRecv = ByteToUint32(lastMsg.HeaderMsg[8:12])
|
|
||||||
if q.maxLength != 0 && lastMsg.LengthRecv > q.maxLength {
|
|
||||||
q.unFinMsg.Delete(conn)
|
|
||||||
return fmt.Errorf("msg length is %d ,too large than %d", lastMsg.LengthRecv, q.maxLength)
|
|
||||||
}
|
|
||||||
//获得本数据包ID
|
|
||||||
lastMsg.ID = ByteToUint16(lastMsg.HeaderMsg[12:14])
|
|
||||||
//存入列表
|
|
||||||
q.unFinMsg.Store(conn, lastMsg)
|
|
||||||
msg = msg[14-headerLen:]
|
|
||||||
if uint32(len(msg)) < lastMsg.LengthRecv {
|
|
||||||
lastMsg.RecvMsg = msg
|
|
||||||
q.unFinMsg.Store(conn, lastMsg)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if uint32(len(msg)) >= lastMsg.LengthRecv {
|
|
||||||
lastMsg.RecvMsg = msg[0:lastMsg.LengthRecv]
|
|
||||||
if q.Encode {
|
|
||||||
lastMsg.RecvMsg = q.DecodeFunc(lastMsg.RecvMsg)
|
|
||||||
}
|
|
||||||
msg = msg[lastMsg.LengthRecv:]
|
|
||||||
storeMsg := MsgQueue{
|
|
||||||
ID: lastMsg.ID,
|
|
||||||
Msg: lastMsg.RecvMsg,
|
|
||||||
Conn: conn,
|
|
||||||
}
|
|
||||||
//q.restoreMu.Lock()
|
|
||||||
q.push2list(storeMsg)
|
|
||||||
//q.restoreMu.Unlock()
|
|
||||||
q.unFinMsg.Delete(conn)
|
|
||||||
return q.parseMessage(msg, conn)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
lastID := int(lastMsg.LengthRecv) - len(lastMsg.RecvMsg)
|
|
||||||
if lastID < 0 {
|
|
||||||
q.unFinMsg.Delete(conn)
|
|
||||||
return q.parseMessage(msg, conn)
|
|
||||||
}
|
|
||||||
if len(msg) >= lastID {
|
|
||||||
lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg[0:lastID])
|
|
||||||
if q.Encode {
|
|
||||||
lastMsg.RecvMsg = q.DecodeFunc(lastMsg.RecvMsg)
|
|
||||||
}
|
|
||||||
storeMsg := MsgQueue{
|
|
||||||
ID: lastMsg.ID,
|
|
||||||
Msg: lastMsg.RecvMsg,
|
|
||||||
Conn: conn,
|
|
||||||
}
|
|
||||||
q.push2list(storeMsg)
|
|
||||||
q.unFinMsg.Delete(conn)
|
|
||||||
if len(msg) == lastID {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
msg = msg[lastID:]
|
|
||||||
return q.parseMessage(msg, conn)
|
|
||||||
}
|
|
||||||
lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg)
|
|
||||||
q.unFinMsg.Store(conn, lastMsg)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(msg) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var start int
|
|
||||||
if start = searchHeader(msg); start == -1 {
|
|
||||||
return errors.New("data format error")
|
|
||||||
}
|
|
||||||
msg = msg[start:]
|
|
||||||
lastMsg := unFinMsg{}
|
|
||||||
q.unFinMsg.Store(conn, &lastMsg)
|
|
||||||
return q.parseMessage(msg, conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkHeader(msg []byte) bool {
|
|
||||||
if len(msg) != 8 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for k, v := range msg {
|
|
||||||
if v != header[k] {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func searchHeader(msg []byte) int {
|
|
||||||
if len(msg) < 8 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
for k, v := range msg {
|
|
||||||
find := 0
|
|
||||||
if v == header[0] {
|
|
||||||
for k2, v2 := range header {
|
|
||||||
if msg[k+k2] == v2 {
|
|
||||||
find++
|
|
||||||
} else {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if find == 8 {
|
|
||||||
return k
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
|
|
||||||
func bytesMerge(src ...[]byte) []byte {
|
|
||||||
var buff bytes.Buffer
|
|
||||||
for _, v := range src {
|
|
||||||
buff.Write(v)
|
|
||||||
}
|
|
||||||
return buff.Bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Restore 获取收到的信息
|
|
||||||
func (q *StarQueue) Restore() (MsgQueue, error) {
|
|
||||||
if q.duration.Seconds() == 0 {
|
|
||||||
q.duration = 86400 * time.Second
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-q.ctx.Done():
|
|
||||||
return MsgQueue{}, errors.New("Stoped By External Function Call")
|
|
||||||
case <-time.After(q.duration):
|
|
||||||
if q.duration != 0 {
|
|
||||||
return MsgQueue{}, ErrDeadlineExceeded
|
|
||||||
}
|
|
||||||
case data, ok := <-q.msgPool:
|
|
||||||
if !ok {
|
|
||||||
return MsgQueue{}, os.ErrClosed
|
|
||||||
}
|
|
||||||
return data, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RestoreOne 获取收到的一个信息
|
|
||||||
// 兼容性修改
|
|
||||||
func (q *StarQueue) RestoreOne() (MsgQueue, error) {
|
|
||||||
return q.Restore()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop 立即停止Restore
|
|
||||||
func (q *StarQueue) Stop() {
|
|
||||||
q.cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
// RestoreDuration Restore最大超时时间
|
|
||||||
func (q *StarQueue) RestoreDuration(tm time.Duration) {
|
|
||||||
q.duration = tm
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *StarQueue) RestoreChan() <-chan MsgQueue {
|
|
||||||
return q.msgPool
|
|
||||||
}
|
|
||||||
80
que_benchmark_test.go
Normal file
80
que_benchmark_test.go
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
package stario
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkQueueBuildMessage64KiB(b *testing.B) {
|
||||||
|
que := NewQueue()
|
||||||
|
payload := make([]byte, 64*1024)
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.SetBytes(int64(len(payload)))
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = que.BuildMessage(payload)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkQueueWriteFrame64KiB(b *testing.B) {
|
||||||
|
que := NewQueue()
|
||||||
|
payload := make([]byte, 64*1024)
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.SetBytes(int64(len(payload)))
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if err := que.WriteFrame(io.Discard, payload); err != nil {
|
||||||
|
b.Fatalf("WriteFrame failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkQueueWriteFrameBuffers64KiB(b *testing.B) {
|
||||||
|
que := NewQueue()
|
||||||
|
payload := make([]byte, 64*1024)
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.SetBytes(int64(len(payload)))
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if err := que.WriteFrameBuffers(io.Discard, payload); err != nil {
|
||||||
|
b.Fatalf("WriteFrameBuffers failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkQueueWriteFramesBuffers4x64KiB(b *testing.B) {
|
||||||
|
que := NewQueue()
|
||||||
|
payloads := [][]byte{
|
||||||
|
make([]byte, 64*1024),
|
||||||
|
make([]byte, 64*1024),
|
||||||
|
make([]byte, 64*1024),
|
||||||
|
make([]byte, 64*1024),
|
||||||
|
}
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.SetBytes(int64(len(payloads) * len(payloads[0])))
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if err := que.WriteFramesBuffers(io.Discard, payloads...); err != nil {
|
||||||
|
b.Fatalf("WriteFramesBuffers failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkQueueParseRestoreHello(b *testing.B) {
|
||||||
|
que := NewQueueWithCount(1)
|
||||||
|
frame := que.BuildMessage([]byte("hello"))
|
||||||
|
if len(frame) == 0 {
|
||||||
|
b.Fatal("BuildMessage returned empty frame")
|
||||||
|
}
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.SetBytes(int64(len(frame)))
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if err := que.ParseMessage(frame, "bench"); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := que.Restore(); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
34
que_bytes.go
Normal file
34
que_bytes.go
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
package stario
|
||||||
|
|
||||||
|
import "encoding/binary"
|
||||||
|
|
||||||
|
// Uint32ToByte 4位uint32转byte。
|
||||||
|
func Uint32ToByte(src uint32) []byte {
|
||||||
|
res := make([]byte, 4)
|
||||||
|
binary.BigEndian.PutUint32(res, src)
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByteToUint32 byte转4位uint32。
|
||||||
|
func ByteToUint32(src []byte) uint32 {
|
||||||
|
return binary.BigEndian.Uint32(src)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Uint16ToByte 2位uint16转byte。
|
||||||
|
func Uint16ToByte(src uint16) []byte {
|
||||||
|
res := make([]byte, 2)
|
||||||
|
binary.BigEndian.PutUint16(res, src)
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByteToUint16 用于byte转uint16。
|
||||||
|
func ByteToUint16(src []byte) uint16 {
|
||||||
|
return binary.BigEndian.Uint16(src)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneBytes(src []byte) []byte {
|
||||||
|
if len(src) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return append([]byte(nil), src...)
|
||||||
|
}
|
||||||
208
que_frame.go
Normal file
208
que_frame.go
Normal file
@ -0,0 +1,208 @@
|
|||||||
|
package stario
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BuildMessage builds one full frame and panics if the payload is too large to
|
||||||
|
// fit in the framing format.
|
||||||
|
//
|
||||||
|
// New code should prefer BuildMessageErr when it needs an explicit error path.
|
||||||
|
func (q *StarQueue) BuildMessage(src []byte) []byte {
|
||||||
|
frame, err := q.BuildMessageErr(src)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return frame
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildMessageErr builds one full frame and returns an explicit error when the
|
||||||
|
// payload exceeds the 32-bit frame length.
|
||||||
|
func (q *StarQueue) BuildMessageErr(src []byte) ([]byte, error) {
|
||||||
|
payload := src
|
||||||
|
if q.Encode && q.EncodeFunc != nil {
|
||||||
|
payload = q.EncodeFunc(payload)
|
||||||
|
}
|
||||||
|
length, err := payloadSizeToUint32(uint64(len(payload)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
header := q.BuildHeader(length)
|
||||||
|
frame := make([]byte, 0, len(header)+len(payload))
|
||||||
|
frame = append(frame, header...)
|
||||||
|
frame = append(frame, payload...)
|
||||||
|
return frame, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteFrame writes one framed payload directly to w without building an
|
||||||
|
// intermediate full-frame slice first.
|
||||||
|
func (q *StarQueue) WriteFrame(w io.Writer, src []byte) error {
|
||||||
|
if w == nil {
|
||||||
|
return io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
payload := src
|
||||||
|
if q.Encode && q.EncodeFunc != nil {
|
||||||
|
payload = q.EncodeFunc(payload)
|
||||||
|
}
|
||||||
|
length, err := payloadSizeToUint32(uint64(len(payload)))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var header [queHeaderSize]byte
|
||||||
|
writeHeaderBytes(header[:], queHeader{
|
||||||
|
Length: length,
|
||||||
|
Version: queVersionV1,
|
||||||
|
Flags: queSupportedFlags,
|
||||||
|
})
|
||||||
|
if err := writeFull(w, header[:]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return writeFull(w, payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteFrameBuffers writes one framed payload using net.Buffers so callers can
|
||||||
|
// opt into gather writes when the underlying writer supports it well.
|
||||||
|
func (q *StarQueue) WriteFrameBuffers(w io.Writer, src []byte) error {
|
||||||
|
if w == nil {
|
||||||
|
return io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
payload := src
|
||||||
|
if q.Encode && q.EncodeFunc != nil {
|
||||||
|
payload = q.EncodeFunc(payload)
|
||||||
|
}
|
||||||
|
length, err := payloadSizeToUint32(uint64(len(payload)))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var header [queHeaderSize]byte
|
||||||
|
writeHeaderBytes(header[:], queHeader{
|
||||||
|
Length: length,
|
||||||
|
Version: queVersionV1,
|
||||||
|
Flags: queSupportedFlags,
|
||||||
|
})
|
||||||
|
buffers := net.Buffers{header[:], payload}
|
||||||
|
_, err = buffers.WriteTo(w)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteFramesBuffers writes multiple framed payloads using one net.Buffers
|
||||||
|
// batch so callers can reduce write calls on stream transports.
|
||||||
|
func (q *StarQueue) WriteFramesBuffers(w io.Writer, payloads ...[]byte) error {
|
||||||
|
if w == nil {
|
||||||
|
return io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
if len(payloads) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
buffers := make(net.Buffers, 0, len(payloads)*2)
|
||||||
|
headers := make([][queHeaderSize]byte, len(payloads))
|
||||||
|
for i, src := range payloads {
|
||||||
|
payload := src
|
||||||
|
if q.Encode && q.EncodeFunc != nil {
|
||||||
|
payload = q.EncodeFunc(payload)
|
||||||
|
}
|
||||||
|
length, err := payloadSizeToUint32(uint64(len(payload)))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
writeHeaderBytes(headers[i][:], queHeader{
|
||||||
|
Length: length,
|
||||||
|
Version: queVersionV1,
|
||||||
|
Flags: queSupportedFlags,
|
||||||
|
})
|
||||||
|
buffers = append(buffers, headers[i][:], payload)
|
||||||
|
}
|
||||||
|
_, err := buffers.WriteTo(w)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildHeader 生成编码后的Header用于发送。
|
||||||
|
func (q *StarQueue) BuildHeader(length uint32) []byte {
|
||||||
|
return buildHeaderBytes(queHeader{
|
||||||
|
Length: length,
|
||||||
|
Version: queVersionV1,
|
||||||
|
Flags: queSupportedFlags,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildHeaderBytes(header queHeader) []byte {
|
||||||
|
buf := make([]byte, queHeaderSize)
|
||||||
|
writeHeaderBytes(buf, header)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeHeaderBytes(dst []byte, header queHeader) {
|
||||||
|
if len(dst) < queHeaderSize {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
copy(dst[:queMagicSize], queMagic)
|
||||||
|
binary.BigEndian.PutUint32(dst[queMagicSize:queMagicSize+4], header.Length)
|
||||||
|
dst[12] = header.Version
|
||||||
|
dst[13] = header.Flags
|
||||||
|
}
|
||||||
|
|
||||||
|
func payloadSizeToUint32(size uint64) (uint32, error) {
|
||||||
|
const maxFramePayload = ^uint32(0)
|
||||||
|
if size > uint64(maxFramePayload) {
|
||||||
|
return 0, fmt.Errorf("%w: %d > %d", ErrQueueMessageTooLarge, size, uint64(maxFramePayload))
|
||||||
|
}
|
||||||
|
return uint32(size), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeFull(w io.Writer, data []byte) error {
|
||||||
|
for len(data) > 0 {
|
||||||
|
n, err := w.Write(data)
|
||||||
|
if n > 0 {
|
||||||
|
data = data[n:]
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
return io.ErrNoProgress
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseHeaderBytes(src []byte, maxLength uint32) (queHeader, error) {
|
||||||
|
if len(src) < queHeaderSize {
|
||||||
|
return queHeader{}, ErrQueueDataFormat
|
||||||
|
}
|
||||||
|
if !equalMagic(src[:queMagicSize]) {
|
||||||
|
return queHeader{}, ErrQueueDataFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
header := queHeader{
|
||||||
|
Length: ByteToUint32(src[queMagicSize : queMagicSize+4]),
|
||||||
|
Version: src[12],
|
||||||
|
Flags: src[13],
|
||||||
|
}
|
||||||
|
|
||||||
|
if header.Version != queVersionV1 {
|
||||||
|
return queHeader{}, fmt.Errorf("%w: %d", ErrQueueUnsupportedVersion, header.Version)
|
||||||
|
}
|
||||||
|
if header.Flags != queSupportedFlags {
|
||||||
|
return queHeader{}, fmt.Errorf("%w: %d", ErrQueueUnsupportedFlags, header.Flags)
|
||||||
|
}
|
||||||
|
if maxLength != 0 && header.Length > maxLength {
|
||||||
|
return queHeader{}, fmt.Errorf("%w: %d > %d", ErrQueueMessageTooLarge, header.Length, maxLength)
|
||||||
|
}
|
||||||
|
|
||||||
|
return header, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func equalMagic(src []byte) bool {
|
||||||
|
if len(src) != queMagicSize {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i, b := range src {
|
||||||
|
if b != queMagic[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
237
que_parse.go
Normal file
237
que_parse.go
Normal file
@ -0,0 +1,237 @@
|
|||||||
|
package stario
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ParseMessage 用于解析收到的msg信息。
|
||||||
|
func (q *StarQueue) ParseMessage(msg []byte, conn interface{}) error {
|
||||||
|
return q.parseMessage(msg, conn, false, func(payload []byte) error {
|
||||||
|
return q.push2list(MsgQueue{
|
||||||
|
Msg: payload,
|
||||||
|
Conn: conn,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseMessageView parses frames and exposes each payload to fn without
|
||||||
|
// forcing StarQueue to clone it first.
|
||||||
|
//
|
||||||
|
// The provided payload is only valid during the callback. If q uses the legacy
|
||||||
|
// DecodeFunc path, StarQueue still has to allocate a decoded payload first.
|
||||||
|
func (q *StarQueue) ParseMessageView(msg []byte, conn interface{}, fn func(FrameView) error) error {
|
||||||
|
if fn == nil {
|
||||||
|
return ErrQueueFrameHandlerNil
|
||||||
|
}
|
||||||
|
return q.parseMessage(msg, conn, true, func(payload []byte) error {
|
||||||
|
return fn(FrameView{
|
||||||
|
Payload: payload,
|
||||||
|
Conn: conn,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseMessageOwned parses frames and emits owned payload copies to fn without
|
||||||
|
// routing them through RestoreChan.
|
||||||
|
//
|
||||||
|
// Compared with ParseMessage, this keeps StarQueue state handling the same but
|
||||||
|
// lets callers decide how to dispatch parsed messages themselves.
|
||||||
|
func (q *StarQueue) ParseMessageOwned(msg []byte, conn interface{}, fn func(MsgQueue) error) error {
|
||||||
|
if fn == nil {
|
||||||
|
return ErrQueueFrameHandlerNil
|
||||||
|
}
|
||||||
|
frames := make([]MsgQueue, 0, 1)
|
||||||
|
parseErr := q.parseMessage(msg, conn, false, func(payload []byte) error {
|
||||||
|
frames = append(frames, MsgQueue{
|
||||||
|
Msg: payload,
|
||||||
|
Conn: conn,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
for _, frame := range frames {
|
||||||
|
if err := fn(frame); err != nil {
|
||||||
|
if parseErr != nil {
|
||||||
|
return fmt.Errorf("%v: %w", parseErr, err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return parseErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *StarQueue) parseMessage(msg []byte, conn interface{}, borrowPayload bool, emit func([]byte) error) error {
|
||||||
|
state, err := q.connState(conn)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var firstErr error
|
||||||
|
for {
|
||||||
|
payload, ok, err := q.nextPayload(state, conn, msg, borrowPayload)
|
||||||
|
msg = nil
|
||||||
|
if err != nil && firstErr == nil {
|
||||||
|
firstErr = err
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err := emit(payload); err != nil {
|
||||||
|
if firstErr != nil {
|
||||||
|
return fmt.Errorf("%v: %w", firstErr, err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return firstErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *StarQueue) nextPayload(state *queConnState, conn interface{}, msg []byte, borrowPayload bool) ([]byte, bool, error) {
|
||||||
|
state.mu.Lock()
|
||||||
|
defer state.mu.Unlock()
|
||||||
|
|
||||||
|
if len(msg) != 0 {
|
||||||
|
state.buf = append(state.buf, msg...)
|
||||||
|
}
|
||||||
|
|
||||||
|
var firstErr error
|
||||||
|
for {
|
||||||
|
synced, err := syncFrameStart(&state.buf)
|
||||||
|
if err != nil && firstErr == nil {
|
||||||
|
firstErr = err
|
||||||
|
}
|
||||||
|
if !synced {
|
||||||
|
if len(state.buf) == 0 {
|
||||||
|
q.states.Delete(conn)
|
||||||
|
}
|
||||||
|
return nil, false, firstErr
|
||||||
|
}
|
||||||
|
if len(state.buf) < queHeaderSize {
|
||||||
|
return nil, false, firstErr
|
||||||
|
}
|
||||||
|
|
||||||
|
header, err := parseHeaderBytes(state.buf[:queHeaderSize], q.maxLength)
|
||||||
|
if err != nil {
|
||||||
|
if firstErr == nil {
|
||||||
|
firstErr = err
|
||||||
|
}
|
||||||
|
state.buf = shrinkBuffer(state.buf[1:])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
frameLen := queHeaderSize + int(header.Length)
|
||||||
|
if len(state.buf) < frameLen {
|
||||||
|
return nil, false, firstErr
|
||||||
|
}
|
||||||
|
|
||||||
|
payload, rest := extractPayload(state.buf, frameLen, borrowPayload && !(q.Encode && q.DecodeFunc != nil))
|
||||||
|
state.buf = rest
|
||||||
|
if q.Encode && q.DecodeFunc != nil {
|
||||||
|
payload = q.DecodeFunc(payload)
|
||||||
|
}
|
||||||
|
if len(state.buf) == 0 {
|
||||||
|
q.states.Delete(conn)
|
||||||
|
}
|
||||||
|
return payload, true, firstErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractPayload(buf []byte, frameLen int, borrowPayload bool) ([]byte, []byte) {
|
||||||
|
payload := buf[queHeaderSize:frameLen]
|
||||||
|
if !borrowPayload {
|
||||||
|
return cloneBytes(payload), shrinkBuffer(buf[frameLen:])
|
||||||
|
}
|
||||||
|
if frameLen == len(buf) {
|
||||||
|
return payload, nil
|
||||||
|
}
|
||||||
|
return payload, cloneBytes(buf[frameLen:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *StarQueue) push2list(msg MsgQueue) error {
|
||||||
|
select {
|
||||||
|
case <-q.ctx.Done():
|
||||||
|
return q.ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
q.sendMu.RLock()
|
||||||
|
defer q.sendMu.RUnlock()
|
||||||
|
select {
|
||||||
|
case <-q.ctx.Done():
|
||||||
|
return q.ctx.Err()
|
||||||
|
case q.msgPool <- msg:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateConnKey(conn interface{}) error {
|
||||||
|
if conn == nil {
|
||||||
|
return ErrQueueConnKeyNil
|
||||||
|
}
|
||||||
|
typ := reflect.TypeOf(conn)
|
||||||
|
if typ != nil && !typ.Comparable() {
|
||||||
|
return ErrQueueConnKeyInvalid
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *StarQueue) connState(conn interface{}) (*queConnState, error) {
|
||||||
|
if err := validateConnKey(conn); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
state, _ := q.states.LoadOrStore(conn, &queConnState{})
|
||||||
|
return state.(*queConnState), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func syncFrameStart(buf *[]byte) (bool, error) {
|
||||||
|
if len(*buf) == 0 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(*buf) >= queMagicSize && equalMagic((*buf)[:queMagicSize]) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
idx := bytes.Index(*buf, queMagic)
|
||||||
|
if idx == 0 {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
if idx > 0 {
|
||||||
|
*buf = cloneBytes((*buf)[idx:])
|
||||||
|
return true, ErrQueueDataFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
keep := trailingMagicPrefixLen(*buf)
|
||||||
|
if keep == len(*buf) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if keep > 0 {
|
||||||
|
*buf = cloneBytes((*buf)[len(*buf)-keep:])
|
||||||
|
return false, ErrQueueDataFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
*buf = (*buf)[:0]
|
||||||
|
return false, ErrQueueDataFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
func trailingMagicPrefixLen(buf []byte) int {
|
||||||
|
max := len(buf)
|
||||||
|
if max > queMagicSize-1 {
|
||||||
|
max = queMagicSize - 1
|
||||||
|
}
|
||||||
|
for keep := max; keep > 0; keep-- {
|
||||||
|
if bytes.Equal(buf[len(buf)-keep:], queMagic[:keep]) {
|
||||||
|
return keep
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func shrinkBuffer(buf []byte) []byte {
|
||||||
|
if len(buf) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if cap(buf) > len(buf)*4 {
|
||||||
|
return cloneBytes(buf)
|
||||||
|
}
|
||||||
|
return buf
|
||||||
|
}
|
||||||
80
que_runtime.go
Normal file
80
que_runtime.go
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
package stario
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (q *StarQueue) closedRestoreErr() error {
|
||||||
|
if err := q.ctx.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return os.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore blocks until one message is available, the queue is stopped, or the
|
||||||
|
// configured timeout expires.
|
||||||
|
func (q *StarQueue) Restore() (MsgQueue, error) {
|
||||||
|
if q.duration <= 0 {
|
||||||
|
select {
|
||||||
|
case <-q.ctx.Done():
|
||||||
|
return MsgQueue{}, q.ctx.Err()
|
||||||
|
case data, ok := <-q.msgPool:
|
||||||
|
if !ok {
|
||||||
|
return MsgQueue{}, q.closedRestoreErr()
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
timer := time.NewTimer(q.duration)
|
||||||
|
defer timer.Stop()
|
||||||
|
select {
|
||||||
|
case <-q.ctx.Done():
|
||||||
|
return MsgQueue{}, q.ctx.Err()
|
||||||
|
case <-timer.C:
|
||||||
|
return MsgQueue{}, ErrDeadlineExceeded
|
||||||
|
case data, ok := <-q.msgPool:
|
||||||
|
if !ok {
|
||||||
|
return MsgQueue{}, q.closedRestoreErr()
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RestoreOne 获取收到的一个信息。
|
||||||
|
// 兼容性修改。
|
||||||
|
func (q *StarQueue) RestoreOne() (MsgQueue, error) {
|
||||||
|
return q.Restore()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop cancels the queue runtime.
|
||||||
|
//
|
||||||
|
// After Stop returns, Restore unblocks with context.Canceled and RestoreChan is
|
||||||
|
// eventually closed.
|
||||||
|
func (q *StarQueue) Stop() {
|
||||||
|
q.cancel()
|
||||||
|
q.shutdown()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *StarQueue) shutdown() {
|
||||||
|
q.stopOnce.Do(func() {
|
||||||
|
q.sendMu.Lock()
|
||||||
|
close(q.msgPool)
|
||||||
|
q.sendMu.Unlock()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RestoreDuration sets the Restore timeout. A non-positive duration means wait
|
||||||
|
// forever until a message arrives or Stop is called.
|
||||||
|
func (q *StarQueue) RestoreDuration(tm time.Duration) {
|
||||||
|
q.duration = tm
|
||||||
|
}
|
||||||
|
|
||||||
|
// RestoreChan exposes the parsed message stream.
|
||||||
|
//
|
||||||
|
// The returned channel is closed after Stop or when the queue context is
|
||||||
|
// canceled. New code should still prefer Restore when it needs timeout/error
|
||||||
|
// classification instead of a plain stream.
|
||||||
|
func (q *StarQueue) RestoreChan() <-chan MsgQueue {
|
||||||
|
return q.msgPool
|
||||||
|
}
|
||||||
536
que_test.go
536
que_test.go
@ -1,42 +1,520 @@
|
|||||||
package stario
|
package stario
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_QueSpeed(t *testing.T) {
|
func TestQueueBuildMessageUsesVersionedHeader(t *testing.T) {
|
||||||
que := NewQueueWithCount(0)
|
que := NewQueue()
|
||||||
stop := make(chan struct{}, 1)
|
frame := que.BuildMessage([]byte("hello"))
|
||||||
que.RestoreDuration(time.Second * 10)
|
|
||||||
var count int64
|
if len(frame) != queHeaderSize+5 {
|
||||||
|
t.Fatalf("unexpected frame length: got %d want %d", len(frame), queHeaderSize+5)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(frame[:queMagicSize], queMagic) {
|
||||||
|
t.Fatalf("unexpected magic: %v", frame[:queMagicSize])
|
||||||
|
}
|
||||||
|
if got := ByteToUint32(frame[queMagicSize : queMagicSize+4]); got != 5 {
|
||||||
|
t.Fatalf("unexpected payload length: got %d want 5", got)
|
||||||
|
}
|
||||||
|
if frame[12] != queVersionV1 {
|
||||||
|
t.Fatalf("unexpected version: got %d want %d", frame[12], queVersionV1)
|
||||||
|
}
|
||||||
|
if frame[13] != queSupportedFlags {
|
||||||
|
t.Fatalf("unexpected flags: got %d want %d", frame[13], queSupportedFlags)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(frame[queHeaderSize:], []byte("hello")) {
|
||||||
|
t.Fatalf("unexpected payload: %q", frame[queHeaderSize:])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueueWriteFrameMatchesBuildMessage(t *testing.T) {
|
||||||
|
que := NewQueue()
|
||||||
|
want := que.BuildMessage([]byte("hello"))
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := que.WriteFrame(&buf, []byte("hello")); err != nil {
|
||||||
|
t.Fatalf("WriteFrame failed: %v", err)
|
||||||
|
}
|
||||||
|
if got := buf.Bytes(); !bytes.Equal(got, want) {
|
||||||
|
t.Fatalf("WriteFrame mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueueWriteFrameBuffersMatchesBuildMessage(t *testing.T) {
|
||||||
|
que := NewQueue()
|
||||||
|
want := que.BuildMessage([]byte("hello"))
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := que.WriteFrameBuffers(&buf, []byte("hello")); err != nil {
|
||||||
|
t.Fatalf("WriteFrameBuffers failed: %v", err)
|
||||||
|
}
|
||||||
|
if got := buf.Bytes(); !bytes.Equal(got, want) {
|
||||||
|
t.Fatalf("WriteFrameBuffers mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueueWriteFramesBuffersMatchesBuildMessage(t *testing.T) {
|
||||||
|
que := NewQueue()
|
||||||
|
payloads := [][]byte{
|
||||||
|
[]byte("hello"),
|
||||||
|
[]byte("world"),
|
||||||
|
[]byte("batch"),
|
||||||
|
}
|
||||||
|
|
||||||
|
var want []byte
|
||||||
|
for _, payload := range payloads {
|
||||||
|
want = append(want, que.BuildMessage(payload)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := que.WriteFramesBuffers(&buf, payloads...); err != nil {
|
||||||
|
t.Fatalf("WriteFramesBuffers failed: %v", err)
|
||||||
|
}
|
||||||
|
if got := buf.Bytes(); !bytes.Equal(got, want) {
|
||||||
|
t.Fatalf("WriteFramesBuffers mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueueWriteFrameHonorsEncodeFunc(t *testing.T) {
|
||||||
|
que := NewQueue()
|
||||||
|
que.Encode = true
|
||||||
|
que.EncodeFunc = bytes.ToUpper
|
||||||
|
want := que.BuildMessage([]byte("hello"))
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := que.WriteFrame(&buf, []byte("hello")); err != nil {
|
||||||
|
t.Fatalf("WriteFrame failed: %v", err)
|
||||||
|
}
|
||||||
|
if got := buf.Bytes(); !bytes.Equal(got, want) {
|
||||||
|
t.Fatalf("WriteFrame mismatch with EncodeFunc: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueueWriteFrameBuffersHonorsEncodeFunc(t *testing.T) {
|
||||||
|
que := NewQueue()
|
||||||
|
que.Encode = true
|
||||||
|
que.EncodeFunc = bytes.ToUpper
|
||||||
|
want := que.BuildMessage([]byte("hello"))
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := que.WriteFrameBuffers(&buf, []byte("hello")); err != nil {
|
||||||
|
t.Fatalf("WriteFrameBuffers failed: %v", err)
|
||||||
|
}
|
||||||
|
if got := buf.Bytes(); !bytes.Equal(got, want) {
|
||||||
|
t.Fatalf("WriteFrameBuffers mismatch with EncodeFunc: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueueWriteFramesBuffersHonorsEncodeFunc(t *testing.T) {
|
||||||
|
que := NewQueue()
|
||||||
|
que.Encode = true
|
||||||
|
que.EncodeFunc = bytes.ToUpper
|
||||||
|
payloads := [][]byte{
|
||||||
|
[]byte("hello"),
|
||||||
|
[]byte("batch"),
|
||||||
|
}
|
||||||
|
|
||||||
|
var want []byte
|
||||||
|
for _, payload := range payloads {
|
||||||
|
want = append(want, que.BuildMessage(payload)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := que.WriteFramesBuffers(&buf, payloads...); err != nil {
|
||||||
|
t.Fatalf("WriteFramesBuffers failed: %v", err)
|
||||||
|
}
|
||||||
|
if got := buf.Bytes(); !bytes.Equal(got, want) {
|
||||||
|
t.Fatalf("WriteFramesBuffers mismatch with EncodeFunc: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueueParseMessageSplitAcrossCalls(t *testing.T) {
|
||||||
|
que := NewQueueWithCount(1)
|
||||||
|
frame := que.BuildMessage([]byte("hello"))
|
||||||
|
|
||||||
|
if err := que.ParseMessage(frame[:3], "split"); err != nil {
|
||||||
|
t.Fatalf("unexpected error on partial magic: %v", err)
|
||||||
|
}
|
||||||
|
if err := que.ParseMessage(frame[3:11], "split"); err != nil {
|
||||||
|
t.Fatalf("unexpected error on partial header: %v", err)
|
||||||
|
}
|
||||||
|
if err := que.ParseMessage(frame[11:], "split"); err != nil {
|
||||||
|
t.Fatalf("unexpected error on payload completion: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case data := <-que.RestoreChan():
|
||||||
|
if data.ID != 0 {
|
||||||
|
t.Fatalf("expected deprecated frame ID to stay zero, got %d", data.ID)
|
||||||
|
}
|
||||||
|
if data.Conn != "split" {
|
||||||
|
t.Fatalf("unexpected conn: %#v", data.Conn)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(data.Msg, []byte("hello")) {
|
||||||
|
t.Fatalf("unexpected payload: %q", data.Msg)
|
||||||
|
}
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Fatal("did not restore parsed frame")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueueParseMessageViewSplitAcrossCalls(t *testing.T) {
|
||||||
|
que := NewQueue()
|
||||||
|
frame := que.BuildMessage([]byte("hello"))
|
||||||
|
var got [][]byte
|
||||||
|
|
||||||
|
handler := func(view FrameView) error {
|
||||||
|
got = append(got, cloneBytes(view.Payload))
|
||||||
|
if view.Conn != "split-view" {
|
||||||
|
t.Fatalf("unexpected conn: %#v", view.Conn)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := que.ParseMessageView(frame[:3], "split-view", handler); err != nil {
|
||||||
|
t.Fatalf("unexpected error on partial magic: %v", err)
|
||||||
|
}
|
||||||
|
if err := que.ParseMessageView(frame[3:11], "split-view", handler); err != nil {
|
||||||
|
t.Fatalf("unexpected error on partial header: %v", err)
|
||||||
|
}
|
||||||
|
if err := que.ParseMessageView(frame[11:], "split-view", handler); err != nil {
|
||||||
|
t.Fatalf("unexpected error on payload completion: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(got) != 1 {
|
||||||
|
t.Fatalf("unexpected frame count: got %d want 1", len(got))
|
||||||
|
}
|
||||||
|
if !bytes.Equal(got[0], []byte("hello")) {
|
||||||
|
t.Fatalf("unexpected payload: %q", got[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueueParseMessageOwnedSplitAcrossCalls(t *testing.T) {
|
||||||
|
que := NewQueue()
|
||||||
|
frame := que.BuildMessage([]byte("hello"))
|
||||||
|
var got []MsgQueue
|
||||||
|
|
||||||
|
handler := func(msg MsgQueue) error {
|
||||||
|
got = append(got, MsgQueue{
|
||||||
|
Msg: cloneBytes(msg.Msg),
|
||||||
|
Conn: msg.Conn,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := que.ParseMessageOwned(frame[:3], "split-owned", handler); err != nil {
|
||||||
|
t.Fatalf("unexpected error on partial magic: %v", err)
|
||||||
|
}
|
||||||
|
if err := que.ParseMessageOwned(frame[3:11], "split-owned", handler); err != nil {
|
||||||
|
t.Fatalf("unexpected error on partial header: %v", err)
|
||||||
|
}
|
||||||
|
if err := que.ParseMessageOwned(frame[11:], "split-owned", handler); err != nil {
|
||||||
|
t.Fatalf("unexpected error on payload completion: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(got) != 1 {
|
||||||
|
t.Fatalf("unexpected frame count: got %d want 1", len(got))
|
||||||
|
}
|
||||||
|
if got[0].Conn != "split-owned" {
|
||||||
|
t.Fatalf("unexpected conn: %#v", got[0].Conn)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(got[0].Msg, []byte("hello")) {
|
||||||
|
t.Fatalf("unexpected payload: %q", got[0].Msg)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case msg := <-que.RestoreChan():
|
||||||
|
t.Fatalf("ParseMessageOwned should not use RestoreChan, got %#v", msg)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueueParseMessageSkipsGarbagePrefix(t *testing.T) {
|
||||||
|
que := NewQueueWithCount(1)
|
||||||
|
frame := que.BuildMessage([]byte("hello"))
|
||||||
|
|
||||||
|
err := que.ParseMessage(append([]byte("junk"), frame...), "garbage")
|
||||||
|
if !errors.Is(err, ErrQueueDataFormat) {
|
||||||
|
t.Fatalf("expected data format error, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case data := <-que.RestoreChan():
|
||||||
|
if !bytes.Equal(data.Msg, []byte("hello")) {
|
||||||
|
t.Fatalf("unexpected payload after resync: %q", data.Msg)
|
||||||
|
}
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Fatal("did not restore frame after skipping garbage")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueueParseMessageViewSkipsGarbagePrefix(t *testing.T) {
|
||||||
|
que := NewQueue()
|
||||||
|
frame := que.BuildMessage([]byte("hello"))
|
||||||
|
var got []byte
|
||||||
|
|
||||||
|
err := que.ParseMessageView(append([]byte("junk"), frame...), "garbage-view", func(view FrameView) error {
|
||||||
|
got = cloneBytes(view.Payload)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if !errors.Is(err, ErrQueueDataFormat) {
|
||||||
|
t.Fatalf("expected data format error, got %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(got, []byte("hello")) {
|
||||||
|
t.Fatalf("unexpected payload after resync: %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueueParseMessageViewNilHandler(t *testing.T) {
|
||||||
|
que := NewQueue()
|
||||||
|
err := que.ParseMessageView(que.BuildMessage([]byte("hello")), "nil-handler", nil)
|
||||||
|
if !errors.Is(err, ErrQueueFrameHandlerNil) {
|
||||||
|
t.Fatalf("ParseMessageView error = %v, want %v", err, ErrQueueFrameHandlerNil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueueParseMessageOwnedNilHandler(t *testing.T) {
|
||||||
|
que := NewQueue()
|
||||||
|
err := que.ParseMessageOwned(que.BuildMessage([]byte("hello")), "nil-handler-owned", nil)
|
||||||
|
if !errors.Is(err, ErrQueueFrameHandlerNil) {
|
||||||
|
t.Fatalf("ParseMessageOwned error = %v, want %v", err, ErrQueueFrameHandlerNil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueueWriteFrameNilWriter(t *testing.T) {
|
||||||
|
que := NewQueue()
|
||||||
|
err := que.WriteFrame(nil, []byte("hello"))
|
||||||
|
if !errors.Is(err, io.ErrClosedPipe) {
|
||||||
|
t.Fatalf("WriteFrame error = %v, want %v", err, io.ErrClosedPipe)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueueWriteFrameBuffersNilWriter(t *testing.T) {
|
||||||
|
que := NewQueue()
|
||||||
|
err := que.WriteFrameBuffers(nil, []byte("hello"))
|
||||||
|
if !errors.Is(err, io.ErrClosedPipe) {
|
||||||
|
t.Fatalf("WriteFrameBuffers error = %v, want %v", err, io.ErrClosedPipe)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueueWriteFramesBuffersNilWriter(t *testing.T) {
|
||||||
|
que := NewQueue()
|
||||||
|
err := que.WriteFramesBuffers(nil, []byte("hello"))
|
||||||
|
if !errors.Is(err, io.ErrClosedPipe) {
|
||||||
|
t.Fatalf("WriteFramesBuffers error = %v, want %v", err, io.ErrClosedPipe)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueueParseMessageRejectsUnsupportedVersion(t *testing.T) {
|
||||||
|
que := NewQueueWithCount(1)
|
||||||
|
frame := que.BuildMessage([]byte("hello"))
|
||||||
|
frame[12] = 2
|
||||||
|
|
||||||
|
err := que.ParseMessage(frame, "version")
|
||||||
|
if !errors.Is(err, ErrQueueUnsupportedVersion) {
|
||||||
|
t.Fatalf("expected unsupported version error, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case data := <-que.RestoreChan():
|
||||||
|
t.Fatalf("unexpected restored frame: %#v", data)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueueParseMessageRejectsMessageTooLarge(t *testing.T) {
|
||||||
|
que := NewQueueCtx(nil, 1, 4)
|
||||||
|
frame := que.BuildMessage([]byte("hello"))
|
||||||
|
|
||||||
|
err := que.ParseMessage(frame, "large")
|
||||||
|
if !errors.Is(err, ErrQueueMessageTooLarge) {
|
||||||
|
t.Fatalf("expected message too large error, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case data := <-que.RestoreChan():
|
||||||
|
t.Fatalf("unexpected restored frame: %#v", data)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueueParseMessageRejectsInvalidConnKey(t *testing.T) {
|
||||||
|
que := NewQueue()
|
||||||
|
frame := que.BuildMessage([]byte("hello"))
|
||||||
|
err := que.ParseMessage(frame, []byte("not-comparable"))
|
||||||
|
if !errors.Is(err, ErrQueueConnKeyInvalid) {
|
||||||
|
t.Fatalf("expected invalid conn key error, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueueParseMessageRejectsNilConnKey(t *testing.T) {
|
||||||
|
que := NewQueue()
|
||||||
|
frame := que.BuildMessage([]byte("hello"))
|
||||||
|
err := que.ParseMessage(frame, nil)
|
||||||
|
if !errors.Is(err, ErrQueueConnKeyNil) {
|
||||||
|
t.Fatalf("expected nil conn key error, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueuePayloadSizeToUint32RejectsOverflow(t *testing.T) {
|
||||||
|
_, err := payloadSizeToUint32(uint64(^uint32(0)) + 1)
|
||||||
|
if !errors.Is(err, ErrQueueMessageTooLarge) {
|
||||||
|
t.Fatalf("expected message too large error, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueueRestoreDurationZeroWaitsUntilMessage(t *testing.T) {
|
||||||
|
que := NewQueueWithCount(1)
|
||||||
|
que.RestoreDuration(0)
|
||||||
|
|
||||||
|
type restoreResult struct {
|
||||||
|
msg MsgQueue
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
resultCh := make(chan restoreResult, 1)
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
msg, err := que.Restore()
|
||||||
select {
|
resultCh <- restoreResult{msg: msg, err: err}
|
||||||
case <-stop:
|
|
||||||
//fmt.Println(count)
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
_, err := que.RestoreOne()
|
|
||||||
if err == nil {
|
|
||||||
count++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
cp := 0
|
|
||||||
stoped := time.After(time.Second * 10)
|
|
||||||
data := que.BuildMessage([]byte("hello"))
|
|
||||||
for {
|
|
||||||
select {
|
select {
|
||||||
case <-stoped:
|
case result := <-resultCh:
|
||||||
fmt.Println(count, cp)
|
t.Fatalf("Restore returned too early: %#v", result)
|
||||||
stop <- struct{}{}
|
case <-time.After(50 * time.Millisecond):
|
||||||
return
|
}
|
||||||
default:
|
|
||||||
que.ParseMessage(data, "lala")
|
if err := que.ParseMessage(que.BuildMessage([]byte("hello")), "forever"); err != nil {
|
||||||
cp++
|
t.Fatalf("ParseMessage failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case result := <-resultCh:
|
||||||
|
if result.err != nil {
|
||||||
|
t.Fatalf("Restore returned error: %v", result.err)
|
||||||
|
}
|
||||||
|
if result.msg.Conn != "forever" || !bytes.Equal(result.msg.Msg, []byte("hello")) {
|
||||||
|
t.Fatalf("unexpected restore result: %#v", result.msg)
|
||||||
|
}
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Fatal("Restore did not return after message arrival")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQueueRestoreReturnsContextErrorOnStop(t *testing.T) {
|
||||||
|
que := NewQueue()
|
||||||
|
resultCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
_, err := que.Restore()
|
||||||
|
resultCh <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
que.Stop()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-resultCh:
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatalf("expected context.Canceled, got %v", err)
|
||||||
|
}
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Fatal("Restore did not return after Stop")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueueRestoreChanClosesOnStop(t *testing.T) {
|
||||||
|
que := NewQueue()
|
||||||
|
resultCh := make(chan bool, 1)
|
||||||
|
go func() {
|
||||||
|
_, ok := <-que.RestoreChan()
|
||||||
|
resultCh <- ok
|
||||||
|
}()
|
||||||
|
|
||||||
|
que.Stop()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case ok := <-resultCh:
|
||||||
|
if ok {
|
||||||
|
t.Fatal("expected RestoreChan to close after Stop")
|
||||||
|
}
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Fatal("RestoreChan did not close after Stop")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueueRestoreChanClosesOnContextCancel(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
que := NewQueueCtx(ctx, 1, 0)
|
||||||
|
resultCh := make(chan bool, 1)
|
||||||
|
go func() {
|
||||||
|
_, ok := <-que.RestoreChan()
|
||||||
|
resultCh <- ok
|
||||||
|
}()
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case ok := <-resultCh:
|
||||||
|
if ok {
|
||||||
|
t.Fatal("expected RestoreChan to close after context cancel")
|
||||||
|
}
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Fatal("RestoreChan did not close after context cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueueParseMessageReturnsContextErrorWhenStoppedWhilePoolIsFull(t *testing.T) {
|
||||||
|
que := NewQueueWithCount(1)
|
||||||
|
if err := que.ParseMessage(que.BuildMessage([]byte("first")), "full"); err != nil {
|
||||||
|
t.Fatalf("ParseMessage first failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
errCh <- que.ParseMessage(que.BuildMessage([]byte("second")), "full")
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errCh:
|
||||||
|
t.Fatalf("ParseMessage returned before Stop: %v", err)
|
||||||
|
case <-time.After(50 * time.Millisecond):
|
||||||
|
}
|
||||||
|
|
||||||
|
que.Stop()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errCh:
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatalf("expected context.Canceled, got %v", err)
|
||||||
|
}
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Fatal("ParseMessage did not return after Stop")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueueParseMessageViewAllowsReentrantParseOnSameConn(t *testing.T) {
|
||||||
|
que := NewQueue()
|
||||||
|
reentered := false
|
||||||
|
|
||||||
|
err := que.ParseMessageView(que.BuildMessage([]byte("outer")), "reentrant", func(view FrameView) error {
|
||||||
|
if !bytes.Equal(view.Payload, []byte("outer")) {
|
||||||
|
t.Fatalf("unexpected outer payload: %q", view.Payload)
|
||||||
|
}
|
||||||
|
return que.ParseMessageView(que.BuildMessage([]byte("inner")), "reentrant", func(inner FrameView) error {
|
||||||
|
reentered = true
|
||||||
|
if !bytes.Equal(inner.Payload, []byte("inner")) {
|
||||||
|
t.Fatalf("unexpected inner payload: %q", inner.Payload)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseMessageView failed: %v", err)
|
||||||
|
}
|
||||||
|
if !reentered {
|
||||||
|
t.Fatal("expected reentrant ParseMessageView to run")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
101
que_types.go
Normal file
101
que_types.go
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
package stario
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrDeadlineExceeded = errors.New("deadline exceeded")
|
||||||
|
var ErrQueueDataFormat = errors.New("data format error")
|
||||||
|
var ErrQueueUnsupportedVersion = errors.New("unsupported frame version")
|
||||||
|
var ErrQueueUnsupportedFlags = errors.New("unsupported frame flags")
|
||||||
|
var ErrQueueMessageTooLarge = errors.New("message too large")
|
||||||
|
var ErrQueueFrameHandlerNil = errors.New("frame handler is nil")
|
||||||
|
var ErrQueueConnKeyInvalid = errors.New("queue conn key must be comparable")
|
||||||
|
var ErrQueueConnKeyNil = errors.New("queue conn key must not be nil")
|
||||||
|
|
||||||
|
const (
|
||||||
|
queMagicSize = 8
|
||||||
|
queHeaderSize = 14
|
||||||
|
queVersionV1 = 1
|
||||||
|
queSupportedFlags = 0
|
||||||
|
)
|
||||||
|
|
||||||
|
// 识别头
|
||||||
|
var queMagic = []byte{11, 27, 19, 96, 12, 25, 02, 20}
|
||||||
|
|
||||||
|
// MsgQueue 为基本的信息单位。
|
||||||
|
type MsgQueue struct {
|
||||||
|
// Deprecated: frame-level IDs are no longer emitted by StarQueue v2 framing.
|
||||||
|
ID uint16
|
||||||
|
|
||||||
|
Msg []byte
|
||||||
|
Conn interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FrameView exposes a parsed payload without forcing StarQueue to clone it.
|
||||||
|
//
|
||||||
|
// The payload is only valid during the ParseMessageView callback. Callers must
|
||||||
|
// copy it if they need to keep it after the callback returns.
|
||||||
|
type FrameView struct {
|
||||||
|
Payload []byte
|
||||||
|
Conn interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StarQueue 为流数据中的消息队列分发。
|
||||||
|
type StarQueue struct {
|
||||||
|
maxLength uint32
|
||||||
|
msgPool chan MsgQueue
|
||||||
|
states sync.Map
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
duration time.Duration
|
||||||
|
sendMu sync.RWMutex
|
||||||
|
stopOnce sync.Once
|
||||||
|
|
||||||
|
// Deprecated: new code should keep StarQueue focused on framing only.
|
||||||
|
Encode bool
|
||||||
|
// Deprecated: new code should keep StarQueue focused on framing only.
|
||||||
|
EncodeFunc func([]byte) []byte
|
||||||
|
// Deprecated: new code should keep StarQueue focused on framing only.
|
||||||
|
DecodeFunc func([]byte) []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type queHeader struct {
|
||||||
|
Length uint32
|
||||||
|
Version uint8
|
||||||
|
Flags uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
type queConnState struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
buf []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewQueueCtx(ctx context.Context, count int64, maxMsgLength uint32) *StarQueue {
|
||||||
|
if count < 0 {
|
||||||
|
panic("stario: negative queue count")
|
||||||
|
}
|
||||||
|
q := &StarQueue{
|
||||||
|
maxLength: maxMsgLength,
|
||||||
|
msgPool: make(chan MsgQueue, count),
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
q.ctx, q.cancel = context.WithCancel(context.Background())
|
||||||
|
} else {
|
||||||
|
q.ctx, q.cancel = context.WithCancel(ctx)
|
||||||
|
}
|
||||||
|
context.AfterFunc(q.ctx, q.shutdown)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewQueueWithCount(count int64) *StarQueue {
|
||||||
|
return NewQueueCtx(nil, count, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewQueue 建立一个新消息队列。
|
||||||
|
func NewQueue() *StarQueue {
|
||||||
|
return NewQueueWithCount(32)
|
||||||
|
}
|
||||||
44
signal_error.go
Normal file
44
signal_error.go
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
package stario
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrSignalInterrupt = errors.New("interrupt")
|
||||||
|
ErrSignalStop = errors.New("stop")
|
||||||
|
ErrSignalQuit = errors.New("quit")
|
||||||
|
)
|
||||||
|
|
||||||
|
type inputSignalError struct {
|
||||||
|
msg string
|
||||||
|
cause error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *inputSignalError) Error() string {
|
||||||
|
return e.msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *inputSignalError) Unwrap() error {
|
||||||
|
return e.cause
|
||||||
|
}
|
||||||
|
|
||||||
|
func signalErrorForType(sigtype rune) error {
|
||||||
|
switch sigtype {
|
||||||
|
case 0x03:
|
||||||
|
return &inputSignalError{
|
||||||
|
msg: "SIGNAL SIGINT RECIVED",
|
||||||
|
cause: ErrSignalInterrupt,
|
||||||
|
}
|
||||||
|
case 0x1a:
|
||||||
|
return &inputSignalError{
|
||||||
|
msg: "SIGNAL SIGSTOP RECIVED",
|
||||||
|
cause: ErrSignalStop,
|
||||||
|
}
|
||||||
|
case 0x1c:
|
||||||
|
return &inputSignalError{
|
||||||
|
msg: "SIGNAL SIGQUIT RECIVED",
|
||||||
|
cause: ErrSignalQuit,
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
50
signal_error_test.go
Normal file
50
signal_error_test.go
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
package stario
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSignalErrorForType(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
sig rune
|
||||||
|
msg string
|
||||||
|
want error
|
||||||
|
}{
|
||||||
|
{name: "interrupt", sig: 0x03, msg: "SIGNAL SIGINT RECIVED", want: ErrSignalInterrupt},
|
||||||
|
{name: "stop", sig: 0x1a, msg: "SIGNAL SIGSTOP RECIVED", want: ErrSignalStop},
|
||||||
|
{name: "quit", sig: 0x1c, msg: "SIGNAL SIGQUIT RECIVED", want: ErrSignalQuit},
|
||||||
|
}
|
||||||
|
for _, tc := range cases {
|
||||||
|
err := signalErrorForType(tc.sig)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("%s: expected non-nil error", tc.name)
|
||||||
|
}
|
||||||
|
if err.Error() != tc.msg {
|
||||||
|
t.Fatalf("%s: unexpected error text: got %q want %q", tc.name, err.Error(), tc.msg)
|
||||||
|
}
|
||||||
|
if !errors.Is(err, tc.want) {
|
||||||
|
t.Fatalf("%s: errors.Is mismatch: got %v want %v", tc.name, err, tc.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSignalErrorForUnknownType(t *testing.T) {
|
||||||
|
if err := signalErrorForType('x'); err != nil {
|
||||||
|
t.Fatalf("expected nil error for unknown signal, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSignalReturnsTypedError(t *testing.T) {
|
||||||
|
err := signal(0x03)
|
||||||
|
if !errors.Is(err, ErrSignalInterrupt) {
|
||||||
|
t.Fatalf("expected interrupt error, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSignalReturnsNilForUnknownType(t *testing.T) {
|
||||||
|
if err := signal('x'); err != nil {
|
||||||
|
t.Fatalf("expected nil error for unknown signal, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -3,20 +3,6 @@
|
|||||||
|
|
||||||
package stario
|
package stario
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
func signal(sigtype rune) error {
|
func signal(sigtype rune) error {
|
||||||
//todo: use win32api call signal
|
return signalErrorForType(sigtype)
|
||||||
switch sigtype {
|
|
||||||
case 0x03:
|
|
||||||
return errors.New("SIGNAL SIGINT RECIVED")
|
|
||||||
case 0x1a:
|
|
||||||
return errors.New("SIGNAL SIGSTOP RECIVED")
|
|
||||||
case 0x1c:
|
|
||||||
return errors.New("SIGNAL SIGQUIT RECIVED")
|
|
||||||
default:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,24 +3,6 @@
|
|||||||
|
|
||||||
package stario
|
package stario
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"os"
|
|
||||||
"syscall"
|
|
||||||
)
|
|
||||||
|
|
||||||
func signal(sigtype rune) error {
|
func signal(sigtype rune) error {
|
||||||
switch sigtype {
|
return signalErrorForType(sigtype)
|
||||||
case 0x03:
|
|
||||||
syscall.Kill(os.Getpid(), syscall.SIGINT)
|
|
||||||
return errors.New("SIGNAL SIGINT RECIVED")
|
|
||||||
case 0x1a:
|
|
||||||
syscall.Kill(os.Getpid(), syscall.SIGSTOP)
|
|
||||||
return errors.New("SIGNAL SIGSTOP RECIVED")
|
|
||||||
case 0x1c:
|
|
||||||
syscall.Kill(os.Getpid(), syscall.SIGQUIT)
|
|
||||||
return errors.New("SIGNAL SIGQUIT RECIVED")
|
|
||||||
default:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
123
starring.go
Normal file
123
starring.go
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
package stario
|
||||||
|
|
||||||
|
import (
|
||||||
|
"container/list"
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrStarRingInvalidCapacity = errors.New("star ring capacity must be greater than zero")
|
||||||
|
|
||||||
|
// StarRing keeps the newest bytes in memory with fixed capacity.
|
||||||
|
// It never blocks writers: when capacity is exceeded, oldest bytes are dropped.
|
||||||
|
type StarRing struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
cap int
|
||||||
|
|
||||||
|
size int
|
||||||
|
chunks list.List // each element: []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewStarRing(capacity int) (*StarRing, error) {
|
||||||
|
if capacity <= 0 {
|
||||||
|
return nil, ErrStarRingInvalidCapacity
|
||||||
|
}
|
||||||
|
return &StarRing{cap: capacity}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarRing) Capacity() int {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return s.cap
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarRing) Len() int {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return s.size
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarRing) Reset() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.size = 0
|
||||||
|
s.chunks.Init()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarRing) Write(p []byte) (int, error) {
|
||||||
|
if len(p) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
payload := append([]byte(nil), p...)
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.writeLocked(payload)
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarRing) WriteString(text string) (int, error) {
|
||||||
|
if len(text) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return s.Write([]byte(text))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarRing) Snapshot() []byte {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
out := make([]byte, s.size)
|
||||||
|
pos := 0
|
||||||
|
for node := s.chunks.Front(); node != nil; node = node.Next() {
|
||||||
|
chunk, ok := node.Value.([]byte)
|
||||||
|
if !ok || len(chunk) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
pos += copy(out[pos:], chunk)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarRing) writeLocked(payload []byte) {
|
||||||
|
if len(payload) >= s.cap {
|
||||||
|
payload = payload[len(payload)-s.cap:]
|
||||||
|
s.chunks.Init()
|
||||||
|
s.chunks.PushBack(payload)
|
||||||
|
s.size = len(payload)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.chunks.PushBack(payload)
|
||||||
|
s.size += len(payload)
|
||||||
|
s.trimLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarRing) trimLocked() {
|
||||||
|
if s.size <= s.cap {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
overflow := s.size - s.cap
|
||||||
|
for overflow > 0 {
|
||||||
|
front := s.chunks.Front()
|
||||||
|
if front == nil {
|
||||||
|
s.size = 0
|
||||||
|
return
|
||||||
|
}
|
||||||
|
head, ok := front.Value.([]byte)
|
||||||
|
if !ok || len(head) == 0 {
|
||||||
|
s.chunks.Remove(front)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if len(head) <= overflow {
|
||||||
|
s.chunks.Remove(front)
|
||||||
|
s.size -= len(head)
|
||||||
|
overflow -= len(head)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
trimmed := append([]byte(nil), head[overflow:]...)
|
||||||
|
front.Value = trimmed
|
||||||
|
s.size -= overflow
|
||||||
|
overflow = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
104
starring_test.go
Normal file
104
starring_test.go
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
package stario
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewStarRingRejectsInvalidCapacity(t *testing.T) {
|
||||||
|
ring, err := NewStarRing(0)
|
||||||
|
if err != ErrStarRingInvalidCapacity {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if ring != nil {
|
||||||
|
t.Fatal("expected nil ring when capacity is invalid")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarRingWriteAndSnapshot(t *testing.T) {
|
||||||
|
ring, err := NewStarRing(16)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if n, err := ring.Write([]byte("abc")); err != nil || n != 3 {
|
||||||
|
t.Fatalf("write failed: n=%d err=%v", n, err)
|
||||||
|
}
|
||||||
|
if n, err := ring.WriteString("def"); err != nil || n != 3 {
|
||||||
|
t.Fatalf("write string failed: n=%d err=%v", n, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got := ring.Snapshot()
|
||||||
|
if !bytes.Equal(got, []byte("abcdef")) {
|
||||||
|
t.Fatalf("unexpected snapshot: %q", got)
|
||||||
|
}
|
||||||
|
if ring.Len() != 6 {
|
||||||
|
t.Fatalf("unexpected ring len: %d", ring.Len())
|
||||||
|
}
|
||||||
|
if ring.Capacity() != 16 {
|
||||||
|
t.Fatalf("unexpected ring capacity: %d", ring.Capacity())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarRingOverflowDropsOldestBytesPrecisely(t *testing.T) {
|
||||||
|
ring, err := NewStarRing(10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
_, _ = ring.WriteString("abcde")
|
||||||
|
_, _ = ring.WriteString("fghij")
|
||||||
|
_, _ = ring.WriteString("kl")
|
||||||
|
|
||||||
|
got := ring.Snapshot()
|
||||||
|
want := []byte("cdefghijkl")
|
||||||
|
if !bytes.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected snapshot after overflow: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if ring.Len() != len(want) {
|
||||||
|
t.Fatalf("unexpected ring len after overflow: %d", ring.Len())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarRingLargeWriteKeepsTail(t *testing.T) {
|
||||||
|
ring, err := NewStarRing(5)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
_, _ = ring.WriteString("123")
|
||||||
|
_, _ = ring.WriteString("456789")
|
||||||
|
|
||||||
|
got := ring.Snapshot()
|
||||||
|
want := []byte("56789")
|
||||||
|
if !bytes.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected snapshot after large write: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarRingSnapshotIsCopy(t *testing.T) {
|
||||||
|
ring, err := NewStarRing(8)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
_, _ = ring.WriteString("hello")
|
||||||
|
|
||||||
|
first := ring.Snapshot()
|
||||||
|
first[0] = 'H'
|
||||||
|
second := ring.Snapshot()
|
||||||
|
if !bytes.Equal(second, []byte("hello")) {
|
||||||
|
t.Fatalf("snapshot should not expose internal storage: %q", second)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarRingReset(t *testing.T) {
|
||||||
|
ring, err := NewStarRing(8)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
_, _ = ring.WriteString("hello")
|
||||||
|
ring.Reset()
|
||||||
|
if ring.Len() != 0 {
|
||||||
|
t.Fatalf("expected len=0 after reset, got %d", ring.Len())
|
||||||
|
}
|
||||||
|
if got := ring.Snapshot(); len(got) != 0 {
|
||||||
|
t.Fatalf("expected empty snapshot after reset, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
17
sync.go
17
sync.go
@ -12,6 +12,10 @@ const (
|
|||||||
waitGroupAddModeLoose
|
waitGroupAddModeLoose
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// WaitGroup is a concurrency-limited sync.WaitGroup variant.
|
||||||
|
//
|
||||||
|
// A zero or negative limit means unlimited concurrency. WaitGroup must not be
|
||||||
|
// copied after first use.
|
||||||
type WaitGroup struct {
|
type WaitGroup struct {
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
@ -22,6 +26,7 @@ type WaitGroup struct {
|
|||||||
addMode waitGroupAddMode
|
addMode waitGroupAddMode
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewWaitGroup creates a WaitGroup with the provided concurrency limit.
|
||||||
func NewWaitGroup(maxCount int) WaitGroup {
|
func NewWaitGroup(maxCount int) WaitGroup {
|
||||||
if maxCount < 0 {
|
if maxCount < 0 {
|
||||||
panic("stario: negative max wait count")
|
panic("stario: negative max wait count")
|
||||||
@ -38,6 +43,10 @@ func (w *WaitGroup) init() {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add adjusts the running task count.
|
||||||
|
//
|
||||||
|
// Positive deltas may block when the concurrency limit is reached. Negative
|
||||||
|
// deltas release running slots.
|
||||||
func (w *WaitGroup) Add(delta int) {
|
func (w *WaitGroup) Add(delta int) {
|
||||||
w.init()
|
w.init()
|
||||||
if delta == 0 {
|
if delta == 0 {
|
||||||
@ -87,10 +96,12 @@ func (w *WaitGroup) release(delta int) {
|
|||||||
w.cond.Broadcast()
|
w.cond.Broadcast()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Done releases one running task slot.
|
||||||
func (w *WaitGroup) Done() {
|
func (w *WaitGroup) Done() {
|
||||||
w.Add(-1)
|
w.Add(-1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Go runs fn in a goroutine while accounting for the concurrency limit.
|
||||||
func (w *WaitGroup) Go(fn func()) {
|
func (w *WaitGroup) Go(fn func()) {
|
||||||
w.Add(1)
|
w.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
@ -99,11 +110,13 @@ func (w *WaitGroup) Go(fn func()) {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Wait blocks until all added work has completed.
|
||||||
func (w *WaitGroup) Wait() {
|
func (w *WaitGroup) Wait() {
|
||||||
w.init()
|
w.init()
|
||||||
w.wg.Wait()
|
w.wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetMaxWaitNum returns the current concurrency limit.
|
||||||
func (w *WaitGroup) GetMaxWaitNum() int {
|
func (w *WaitGroup) GetMaxWaitNum() int {
|
||||||
w.init()
|
w.init()
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
@ -111,6 +124,7 @@ func (w *WaitGroup) GetMaxWaitNum() int {
|
|||||||
return w.maxCount
|
return w.maxCount
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetMaxWaitNum updates the concurrency limit.
|
||||||
func (w *WaitGroup) SetMaxWaitNum(num int) {
|
func (w *WaitGroup) SetMaxWaitNum(num int) {
|
||||||
if num < 0 {
|
if num < 0 {
|
||||||
panic("stario: negative max wait count")
|
panic("stario: negative max wait count")
|
||||||
@ -122,6 +136,8 @@ func (w *WaitGroup) SetMaxWaitNum(num int) {
|
|||||||
w.cond.Broadcast()
|
w.cond.Broadcast()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetStrictAddMode controls whether Add(n>1) panics or auto-expands the limit
|
||||||
|
// when the requested batch exceeds the current capacity.
|
||||||
func (w *WaitGroup) SetStrictAddMode(strict bool) {
|
func (w *WaitGroup) SetStrictAddMode(strict bool) {
|
||||||
w.init()
|
w.init()
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
@ -134,6 +150,7 @@ func (w *WaitGroup) SetStrictAddMode(strict bool) {
|
|||||||
w.cond.Broadcast()
|
w.cond.Broadcast()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StrictAddMode reports whether strict batch-add behavior is enabled.
|
||||||
func (w *WaitGroup) StrictAddMode() bool {
|
func (w *WaitGroup) StrictAddMode() bool {
|
||||||
w.init()
|
w.init()
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
|
|||||||
73
terminal.go
Normal file
73
terminal.go
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
package stario
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"golang.org/x/term"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrTerminalNotTTY reports that terminal-only input was requested from a
|
||||||
|
// non-terminal stdin.
|
||||||
|
var ErrTerminalNotTTY = errors.New("terminal input requires a tty")
|
||||||
|
|
||||||
|
var terminalCheckFunc = term.IsTerminal
|
||||||
|
|
||||||
|
// IsTerminal reports whether os.Stdin is attached to a terminal.
|
||||||
|
func IsTerminal() bool {
|
||||||
|
return IsTerminalFile(os.Stdin)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsTerminalFD reports whether fd is attached to a terminal.
|
||||||
|
func IsTerminalFD(fd int) bool {
|
||||||
|
return terminalCheckFunc(fd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsTerminalFile reports whether file is attached to a terminal.
|
||||||
|
func IsTerminalFile(file *os.File) bool {
|
||||||
|
if file == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return IsTerminalFD(int(file.Fd()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadPasswordContext reads one password-style line from the terminal.
|
||||||
|
//
|
||||||
|
// It returns ErrTerminalNotTTY when stdin is not a terminal. If ctx is canceled
|
||||||
|
// while waiting for input, this function returns ctx.Err() without waiting for
|
||||||
|
// the underlying terminal read to finish.
|
||||||
|
func ReadPasswordContext(ctx context.Context, hint string, defaultVal string) (string, error) {
|
||||||
|
return ReadPasswordContextWithMask(ctx, hint, defaultVal, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadPasswordContextWithMask is like ReadPasswordContext but echoes the given
|
||||||
|
// mask string while the user types.
|
||||||
|
func ReadPasswordContextWithMask(ctx context.Context, hint string, defaultVal string, mask string) (string, error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if !IsTerminal() {
|
||||||
|
return "", ErrTerminalNotTTY
|
||||||
|
}
|
||||||
|
session, err := rawTerminalSessionFactory(hint, true)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
resultCh := make(chan InputMsg, 1)
|
||||||
|
go func() {
|
||||||
|
defer session.Close()
|
||||||
|
resultCh <- rawLineInputSession(session, defaultVal, mask, rawInputSignalReturnError)
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case result := <-resultCh:
|
||||||
|
return result.String()
|
||||||
|
case <-ctx.Done():
|
||||||
|
session.Abort()
|
||||||
|
<-resultCh
|
||||||
|
return "", ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
77
terminal_test.go
Normal file
77
terminal_test.go
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
package stario
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func installTerminalCheckStub(t *testing.T, ok bool) {
|
||||||
|
t.Helper()
|
||||||
|
prev := terminalCheckFunc
|
||||||
|
terminalCheckFunc = func(fd int) bool {
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
terminalCheckFunc = prev
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsTerminalFDUsesStub(t *testing.T) {
|
||||||
|
installTerminalCheckStub(t, true)
|
||||||
|
if !IsTerminalFD(1) {
|
||||||
|
t.Fatal("expected terminal check to return true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadPasswordContextRejectsNonTTY(t *testing.T) {
|
||||||
|
installTerminalCheckStub(t, false)
|
||||||
|
_, err := ReadPasswordContext(context.Background(), "pwd: ", "")
|
||||||
|
if !errors.Is(err, ErrTerminalNotTTY) {
|
||||||
|
t.Fatalf("expected ErrTerminalNotTTY, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadPasswordContextReadsValue(t *testing.T) {
|
||||||
|
installTerminalCheckStub(t, true)
|
||||||
|
installRawInputStub(t, "secret\n", nil)
|
||||||
|
|
||||||
|
got, err := ReadPasswordContext(context.Background(), "pwd: ", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadPasswordContext returned error: %v", err)
|
||||||
|
}
|
||||||
|
if got != "secret" {
|
||||||
|
t.Fatalf("unexpected password: %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadPasswordContextCanceledWhileWaiting(t *testing.T) {
|
||||||
|
installTerminalCheckStub(t, true)
|
||||||
|
|
||||||
|
prevFactory := rawTerminalSessionFactory
|
||||||
|
pipeReader, pipeWriter := io.Pipe()
|
||||||
|
rawTerminalSessionFactory = func(hint string, printNewline bool) (*rawTerminalSession, error) {
|
||||||
|
return &rawTerminalSession{
|
||||||
|
input: pipeReader,
|
||||||
|
reader: bufio.NewReader(pipeReader),
|
||||||
|
redrawHint: strings.TrimSpace(hint),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
rawTerminalSessionFactory = prevFactory
|
||||||
|
_ = pipeWriter.Close()
|
||||||
|
_ = pipeReader.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, err := ReadPasswordContext(ctx, "pwd: ", "")
|
||||||
|
if !errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
t.Fatalf("expected deadline exceeded, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
10
tty_other.go
Normal file
10
tty_other.go
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
//go:build !windows
|
||||||
|
// +build !windows
|
||||||
|
|
||||||
|
package stario
|
||||||
|
|
||||||
|
import "os"
|
||||||
|
|
||||||
|
func openRawTerminalInput() (*os.File, error) {
|
||||||
|
return os.OpenFile("/dev/tty", os.O_RDWR, 0)
|
||||||
|
}
|
||||||
10
tty_windows.go
Normal file
10
tty_windows.go
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
//go:build windows
|
||||||
|
// +build windows
|
||||||
|
|
||||||
|
package stario
|
||||||
|
|
||||||
|
import "os"
|
||||||
|
|
||||||
|
func openRawTerminalInput() (*os.File, error) {
|
||||||
|
return os.OpenFile("CONIN$", os.O_RDWR, 0)
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user