Compare commits

..

4 Commits

11 changed files with 1171 additions and 255 deletions

7
.gitignore vendored Normal file
View File

@ -0,0 +1,7 @@
.sentrux/
agent_readme.md
target.md
.gocache
.gocache/
.tmp_*/
.idea

231
circle.go
View File

@ -2,148 +2,191 @@ package stario
import ( import (
"errors" "errors"
"fmt"
"io" "io"
"os"
"runtime"
"sync" "sync"
"sync/atomic"
"time"
) )
var ErrStarBufferInvalidCapacity = errors.New("star buffer capacity must be greater than zero")
var ErrStarBufferClosed = errors.New("star buffer closed")
var ErrStarBufferWriteClosed = errors.New("star buffer write closed")
type StarBuffer struct { type StarBuffer struct {
io.Reader datas []byte
io.Writer pStart uint64
io.Closer pEnd uint64
datas []byte size uint64
pStart uint64 cap uint64
pEnd uint64 isClose bool
cap uint64 isWriteEnd bool
isClose atomic.Value mu sync.Mutex
isEnd atomic.Value notEmpty *sync.Cond
rmu sync.Mutex notFull *sync.Cond
wmu sync.Mutex
} }
func NewStarBuffer(cap uint64) *StarBuffer { func NewStarBuffer(cap uint64) (*StarBuffer, error) {
rtnBuffer := new(StarBuffer) if cap == 0 {
rtnBuffer.cap = cap return nil, ErrStarBufferInvalidCapacity
rtnBuffer.datas = make([]byte, cap) }
rtnBuffer.isClose.Store(false) rtnBuffer := &StarBuffer{
rtnBuffer.isEnd.Store(false) cap: cap,
return rtnBuffer datas: make([]byte, cap),
}
rtnBuffer.notEmpty = sync.NewCond(&rtnBuffer.mu)
rtnBuffer.notFull = sync.NewCond(&rtnBuffer.mu)
return rtnBuffer, nil
} }
func (star *StarBuffer) Free() uint64 { func (star *StarBuffer) Free() uint64 {
return star.cap - star.Len() star.mu.Lock()
defer star.mu.Unlock()
return star.cap - star.size
} }
func (star *StarBuffer) Cap() uint64 { func (star *StarBuffer) Cap() uint64 {
star.mu.Lock()
defer star.mu.Unlock()
return star.cap return star.cap
} }
func (star *StarBuffer) Len() uint64 { func (star *StarBuffer) Len() uint64 {
if star.pEnd >= star.pStart { star.mu.Lock()
return star.pEnd - star.pStart defer star.mu.Unlock()
} return star.size
return star.pEnd - star.pStart + star.cap
} }
func (star *StarBuffer) getByte() (byte, error) { func (star *StarBuffer) getByte() (byte, error) {
if star.isClose.Load().(bool) || (star.Len() == 0 && star.isEnd.Load().(bool)) { star.mu.Lock()
defer star.mu.Unlock()
for star.size == 0 && !star.isWriteEnd && !star.isClose {
star.notEmpty.Wait()
}
if star.size == 0 {
return 0, io.EOF return 0, io.EOF
} }
if star.Len() == 0 { data := star.datas[star.pStart]
return 0, os.ErrNotExist star.pStart = (star.pStart + 1) % star.cap
} star.size--
nowPtr := star.pStart star.notFull.Broadcast()
nextPtr := star.pStart + 1
if nextPtr >= star.cap {
nextPtr = 0
}
data := star.datas[nowPtr]
ok := atomic.CompareAndSwapUint64(&star.pStart, nowPtr, nextPtr)
if !ok {
return 0, os.ErrInvalid
}
return data, nil return data, nil
} }
func (star *StarBuffer) putByte(data byte) error { func (star *StarBuffer) putByte(data byte) error {
if star.isClose.Load().(bool) { star.mu.Lock()
return io.EOF defer star.mu.Unlock()
for star.size == star.cap && !star.isClose && !star.isWriteEnd {
star.notFull.Wait()
} }
nowPtr := star.pEnd if star.isClose {
kariEnd := nowPtr + 1 return ErrStarBufferClosed
if kariEnd == star.cap {
kariEnd = 0
} }
if kariEnd == atomic.LoadUint64(&star.pStart) { if star.isWriteEnd {
for { return ErrStarBufferWriteClosed
time.Sleep(time.Microsecond)
runtime.Gosched()
if kariEnd != atomic.LoadUint64(&star.pStart) {
break
}
}
} }
star.datas[nowPtr] = data star.datas[star.pEnd] = data
if ok := atomic.CompareAndSwapUint64(&star.pEnd, nowPtr, kariEnd); !ok { star.pEnd = (star.pEnd + 1) % star.cap
return os.ErrInvalid star.size++
star.notEmpty.Broadcast()
return nil
}
func (star *StarBuffer) EndWrite() error {
star.mu.Lock()
defer star.mu.Unlock()
if star.isClose {
return ErrStarBufferClosed
} }
star.isWriteEnd = true
star.notEmpty.Broadcast()
star.notFull.Broadcast()
return nil return nil
} }
func (star *StarBuffer) Close() error { func (star *StarBuffer) Close() error {
star.isClose.Store(true) star.mu.Lock()
defer star.mu.Unlock()
star.isClose = true
star.isWriteEnd = true
star.notEmpty.Broadcast()
star.notFull.Broadcast()
return nil return nil
} }
func (star *StarBuffer) Read(buf []byte) (int, error) { func (star *StarBuffer) Read(buf []byte) (int, error) {
if star.isClose.Load().(bool) || (star.Len() == 0 && star.isEnd.Load().(bool)) {
return 0, io.EOF
}
if buf == nil { if buf == nil {
return 0, errors.New("buffer is nil") return 0, errors.New("buffer is nil")
} }
star.rmu.Lock() if len(buf) == 0 {
defer star.rmu.Unlock() return 0, nil
var sum int = 0
for i := 0; i < len(buf); i++ {
data, err := star.getByte()
if err != nil {
if err == io.EOF {
return sum, err
}
if err == os.ErrNotExist {
i--
continue
}
return sum, nil
}
buf[i] = data
sum++
} }
star.mu.Lock()
defer star.mu.Unlock()
for star.size == 0 && !star.isWriteEnd && !star.isClose {
star.notEmpty.Wait()
}
if star.size == 0 {
return 0, io.EOF
}
sum := minInt(len(buf), int(star.size))
first := minInt(sum, int(star.cap-star.pStart))
copy(buf, star.datas[star.pStart:star.pStart+uint64(first)])
second := sum - first
if second > 0 {
copy(buf[first:], star.datas[:second])
}
star.pStart = (star.pStart + uint64(sum)) % star.cap
star.size -= uint64(sum)
star.notFull.Broadcast()
return sum, nil return sum, nil
} }
func (star *StarBuffer) Write(bts []byte) (int, error) { func (star *StarBuffer) Write(bts []byte) (int, error) {
if bts == nil && !star.isEnd.Load().(bool) { if bts == nil {
star.isEnd.Store(true) return 0, star.EndWrite()
}
if len(bts) == 0 {
return 0, nil return 0, nil
} }
if bts == nil || star.isClose.Load().(bool) { star.mu.Lock()
return 0, io.EOF defer star.mu.Unlock()
} sum := 0
star.wmu.Lock() for sum < len(bts) {
defer star.wmu.Unlock() for star.size == star.cap && !star.isClose && !star.isWriteEnd {
var sum = 0 star.notFull.Wait()
for i := 0; i < len(bts); i++ {
err := star.putByte(bts[i])
if err != nil {
fmt.Println("Write bts err:", err)
return sum, err
} }
sum++ if star.isClose {
if sum == 0 {
return 0, ErrStarBufferClosed
}
return sum, ErrStarBufferClosed
}
if star.isWriteEnd {
if sum == 0 {
return 0, ErrStarBufferWriteClosed
}
return sum, ErrStarBufferWriteClosed
}
space := int(star.cap - star.size)
if space == 0 {
continue
}
n := minInt(len(bts)-sum, space)
first := minInt(n, int(star.cap-star.pEnd))
copy(star.datas[star.pEnd:star.pEnd+uint64(first)], bts[sum:sum+first])
second := n - first
if second > 0 {
copy(star.datas[:second], bts[sum+first:sum+n])
}
star.pEnd = (star.pEnd + uint64(n)) % star.cap
star.size += uint64(n)
sum += n
star.notEmpty.Broadcast()
} }
return sum, nil return sum, nil
} }
func minInt(a int, b int) int {
if a < b {
return a
}
return b
}

View File

@ -1,14 +1,91 @@
package stario package stario
import ( import (
"bytes"
"fmt" "fmt"
"io"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
) )
func TestNewStarBufferRejectsZeroCapacity(t *testing.T) {
buf, err := NewStarBuffer(0)
if err != ErrStarBufferInvalidCapacity {
t.Fatalf("unexpected error: %v", err)
}
if buf != nil {
t.Fatal("expected nil buffer when capacity is invalid")
}
}
func TestStarBufferEndWriteDrainsThenEOF(t *testing.T) {
buf, err := NewStarBuffer(4)
if err != nil {
t.Fatal(err)
}
if _, err := buf.Write([]byte("abcd")); err != nil {
t.Fatal(err)
}
if err := buf.EndWrite(); err != nil {
t.Fatal(err)
}
got := make([]byte, 4)
n, err := buf.Read(got)
if err != nil {
t.Fatal(err)
}
if n != 4 || !bytes.Equal(got[:n], []byte("abcd")) {
t.Fatalf("unexpected payload: n=%d data=%q", n, got[:n])
}
n, err = buf.Read(got)
if n != 0 || err != io.EOF {
t.Fatalf("expected EOF after draining buffer, got n=%d err=%v", n, err)
}
if _, err := buf.Write([]byte("x")); err != ErrStarBufferWriteClosed {
t.Fatalf("unexpected write error after EndWrite: %v", err)
}
}
func TestStarBufferCloseAllowsDrain(t *testing.T) {
buf, err := NewStarBuffer(4)
if err != nil {
t.Fatal(err)
}
if _, err := buf.Write([]byte("ab")); err != nil {
t.Fatal(err)
}
if err := buf.Close(); err != nil {
t.Fatal(err)
}
got := make([]byte, 2)
n, err := buf.Read(got)
if err != nil {
t.Fatal(err)
}
if n != 2 || !bytes.Equal(got[:n], []byte("ab")) {
t.Fatalf("unexpected payload after close: n=%d data=%q", n, got[:n])
}
n, err = buf.Read(got)
if n != 0 || err != io.EOF {
t.Fatalf("expected EOF after draining closed buffer, got n=%d err=%v", n, err)
}
if _, err := buf.Write([]byte("x")); err != ErrStarBufferClosed {
t.Fatalf("unexpected write error after Close: %v", err)
}
}
func Test_Circle(t *testing.T) { func Test_Circle(t *testing.T) {
buf := NewStarBuffer(2048) buf, err := NewStarBuffer(2048)
if err != nil {
t.Fatal(err)
}
go func() { go func() {
for { for {
//fmt.Println("write start") //fmt.Println("write start")
@ -37,7 +114,10 @@ func Test_Circle(t *testing.T) {
} }
func Test_Circle_Speed(t *testing.T) { func Test_Circle_Speed(t *testing.T) {
buf := NewStarBuffer(1048976) buf, err := NewStarBuffer(1048976)
if err != nil {
t.Fatal(err)
}
count := uint64(0) count := uint64(0)
for i := 1; i <= 10; i++ { for i := 1; i <= 10; i++ {
go func() { go func() {
@ -61,7 +141,10 @@ func Test_Circle_Speed(t *testing.T) {
} }
func Test_Circle_Speed2(t *testing.T) { func Test_Circle_Speed2(t *testing.T) {
buf := NewStarBuffer(8192) buf, err := NewStarBuffer(8192)
if err != nil {
t.Fatal(err)
}
count := uint64(0) count := uint64(0)
for i := 1; i <= 10; i++ { for i := 1; i <= 10; i++ {
go func() { go func() {

2
go.mod
View File

@ -2,4 +2,4 @@ module b612.me/stario
go 1.16 go 1.16
require golang.org/x/crypto v0.21.0 require golang.org/x/crypto v0.26.0

34
go.sum
View File

@ -1,20 +1,32 @@
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw=
golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@ -22,24 +34,34 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM=
golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU=
golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

317
io.go
View File

@ -16,12 +16,20 @@ type InputMsg struct {
skipSliceSigErr bool skipSliceSigErr bool
} }
type rawTerminalSession struct {
fd int
state *terminal.State
reader *bufio.Reader
redrawHint string
printNewline bool
}
func Passwd(hint string, defaultVal string) InputMsg { func Passwd(hint string, defaultVal string) InputMsg {
return passwd(hint, defaultVal, "", false) return passwd(hint, defaultVal, "", true)
} }
func PasswdWithMask(hint string, defaultVal string, mask string) InputMsg { func PasswdWithMask(hint string, defaultVal string, mask string) InputMsg {
return passwd(hint, defaultVal, mask, false) return passwd(hint, defaultVal, mask, true)
} }
func PasswdResponseSignal(hint string, defaultVal string) InputMsg { func PasswdResponseSignal(hint string, defaultVal string) InputMsg {
@ -36,54 +44,162 @@ func MessageBoxRaw(hint string, defaultVal string) InputMsg {
return messageBox(hint, defaultVal) return messageBox(hint, defaultVal)
} }
func messageBox(hint string, defaultVal string) InputMsg { func newRawTerminalSession(hint string, printNewline bool) (*rawTerminalSession, error) {
var ioBuf []rune
if hint != "" { if hint != "" {
fmt.Print(hint) fmt.Print(hint)
} }
if strings.Index(hint, "\n") >= 0 {
hint = strings.TrimSpace(hint[strings.LastIndex(hint, "\n"):])
}
fd := int(os.Stdin.Fd()) fd := int(os.Stdin.Fd())
state, err := terminal.MakeRaw(fd) state, err := terminal.MakeRaw(fd)
if err != nil { if err != nil {
return InputMsg{msg: "", err: err} return nil, err
} }
defer fmt.Println() return &rawTerminalSession{
defer terminal.Restore(fd, state) fd: fd,
inputReader := bufio.NewReader(os.Stdin) state: state,
for { reader: bufio.NewReader(os.Stdin),
b, _, err := inputReader.ReadRune() redrawHint: promptRedrawHint(hint),
if err != nil { printNewline: printNewline,
return InputMsg{msg: "", err: err} }, nil
} }
if b == 0x0d {
strValue := strings.TrimSpace(string(ioBuf)) func (session *rawTerminalSession) Close() {
if len(strValue) == 0 { if session == nil || session.state == nil {
strValue = defaultVal return
} }
return InputMsg{msg: strValue, err: err} _ = terminal.Restore(session.fd, session.state)
} if session.printNewline {
if b == 0x08 || b == 0x7F { fmt.Println()
if len(ioBuf) > 0 {
ioBuf = ioBuf[:len(ioBuf)-1]
}
fmt.Print("\r")
for i := 0; i < len(ioBuf)+2+len(hint); i++ {
fmt.Print(" ")
}
} else {
ioBuf = append(ioBuf, b)
}
fmt.Print("\r")
if hint != "" {
fmt.Print(hint)
}
fmt.Print(string(ioBuf))
} }
} }
func isSiganl(s rune) bool { func (session *rawTerminalSession) Restore() error {
if session == nil || session.state == nil {
return nil
}
return terminal.Restore(session.fd, session.state)
}
func promptRedrawHint(hint string) string {
if strings.Index(hint, "\n") >= 0 {
hint = hint[strings.LastIndex(hint, "\n")+1:]
}
return strings.TrimSpace(hint)
}
func finalizeInputValue(raw string, defaultVal string) string {
raw = strings.TrimSpace(raw)
if len(raw) == 0 {
return defaultVal
}
return raw
}
func renderRawEcho(ioBuf []rune, mask string) string {
if mask == "" {
return string(ioBuf)
}
return strings.Repeat(mask, len(ioBuf))
}
func redrawPromptLine(hint string, echo string, lastWidth int) int {
nowWidth := stringDisplayWidth(hint) + stringDisplayWidth(echo)
clearWidth := lastWidth
if nowWidth > clearWidth {
clearWidth = nowWidth
}
fmt.Print("\r")
if clearWidth > 0 {
fmt.Print(strings.Repeat(" ", clearWidth))
fmt.Print("\r")
}
if hint != "" {
fmt.Print(hint)
}
if echo != "" {
fmt.Print(echo)
}
return nowWidth
}
func stringDisplayWidth(text string) int {
width := 0
for _, r := range text {
width += runeDisplayWidth(r)
}
return width
}
func runeDisplayWidth(r rune) int {
switch {
case r == '\t':
return 4
case r < 0x20 || (r >= 0x7f && r < 0xa0):
return 0
case isWideRune(r):
return 2
default:
return 1
}
}
func isWideRune(r rune) bool {
return r >= 0x1100 && (r <= 0x115f ||
r == 0x2329 || r == 0x232a ||
(r >= 0x2e80 && r <= 0xa4cf && r != 0x303f) ||
(r >= 0xac00 && r <= 0xd7a3) ||
(r >= 0xf900 && r <= 0xfaff) ||
(r >= 0xfe10 && r <= 0xfe19) ||
(r >= 0xfe30 && r <= 0xfe6f) ||
(r >= 0xff00 && r <= 0xff60) ||
(r >= 0xffe0 && r <= 0xffe6) ||
(r >= 0x1f300 && r <= 0x1f64f) ||
(r >= 0x1f900 && r <= 0x1f9ff) ||
(r >= 0x20000 && r <= 0x3fffd))
}
func rawLineInput(hint string, defaultVal string, mask string, handleSignal bool) InputMsg {
session, err := newRawTerminalSession(hint, true)
if err != nil {
return InputMsg{msg: "", err: err}
}
defer session.Close()
ioBuf := make([]rune, 0, 16)
lastWidth := 0
for {
b, _, err := session.reader.ReadRune()
if err != nil {
return InputMsg{msg: "", err: err}
}
if handleSignal && isSignal(b) {
if runtime.GOOS != "windows" {
if err := session.Restore(); err != nil {
return InputMsg{msg: "", err: err}
}
}
if err := signal(b); err != nil {
return InputMsg{msg: "", err: err}
}
continue
}
switch b {
case 0x0d, 0x0a:
return InputMsg{msg: finalizeInputValue(string(ioBuf), defaultVal), err: nil}
case 0x08, 0x7F:
if len(ioBuf) > 0 {
ioBuf = ioBuf[:len(ioBuf)-1]
}
default:
ioBuf = append(ioBuf, b)
}
lastWidth = redrawPromptLine(session.redrawHint, renderRawEcho(ioBuf, mask), lastWidth)
}
}
func messageBox(hint string, defaultVal string) InputMsg {
return rawLineInput(hint, defaultVal, "", false)
}
func isSignal(s rune) bool {
switch s { switch s {
case 0x03, 0x1a, 0x1c: case 0x03, 0x1a, 0x1c:
return true return true
@ -93,63 +209,7 @@ func isSiganl(s rune) bool {
} }
func passwd(hint string, defaultVal string, mask string, handleSignal bool) InputMsg { func passwd(hint string, defaultVal string, mask string, handleSignal bool) InputMsg {
var ioBuf []rune return rawLineInput(hint, defaultVal, mask, handleSignal)
if hint != "" {
fmt.Print(hint)
}
if strings.Index(hint, "\n") >= 0 {
hint = strings.TrimSpace(hint[strings.LastIndex(hint, "\n"):])
}
fd := int(os.Stdin.Fd())
state, err := terminal.MakeRaw(fd)
if err != nil {
return InputMsg{msg: "", err: err}
}
defer fmt.Println()
defer terminal.Restore(fd, state)
inputReader := bufio.NewReader(os.Stdin)
for {
b, _, err := inputReader.ReadRune()
if err != nil {
return InputMsg{msg: "", err: err}
}
if handleSignal && isSiganl(b) {
if runtime.GOOS != "windows" {
terminal.Restore(fd, state)
}
if err := signal(b); err != nil {
return InputMsg{
msg: "",
err: err,
}
}
}
if b == 0x0d {
strValue := strings.TrimSpace(string(ioBuf))
if len(strValue) == 0 {
strValue = defaultVal
}
return InputMsg{msg: strValue, err: err}
}
if b == 0x08 || b == 0x7F {
if len(ioBuf) > 0 {
ioBuf = ioBuf[:len(ioBuf)-1]
}
fmt.Print("\r")
for i := 0; i < len(ioBuf)+2+len(hint); i++ {
fmt.Print(" ")
}
} else {
ioBuf = append(ioBuf, b)
}
fmt.Print("\r")
if hint != "" {
fmt.Print(hint)
}
for i := 0; i < len(ioBuf); i++ {
fmt.Print(mask)
}
}
} }
func MessageBox(hint string, defaultVal string) InputMsg { func MessageBox(hint string, defaultVal string) InputMsg {
@ -161,14 +221,11 @@ func MessageBox(hint string, defaultVal string) InputMsg {
if err != nil { if err != nil {
return InputMsg{msg: str, err: err} return InputMsg{msg: str, err: err}
} }
str = strings.TrimSpace(str) str = finalizeInputValue(str, defaultVal)
if len(str) == 0 {
str = defaultVal
}
return InputMsg{msg: str, err: err} return InputMsg{msg: str, err: err}
} }
func (im InputMsg) IgnoreSliceParseError(i bool) InputMsg { func (im *InputMsg) IgnoreSliceParseError(i bool) *InputMsg {
im.skipSliceSigErr = i im.skipSliceSigErr = i
return im return im
} }
@ -189,6 +246,9 @@ func (im InputMsg) SliceString(sep string) ([]string, error) {
if im.err != nil { if im.err != nil {
return nil, im.err return nil, im.err
} }
if len(strings.TrimSpace(im.msg)) == 0 {
return []string{}, nil
}
return strings.Split(im.msg, sep), nil return strings.Split(im.msg, sep), nil
} }
@ -383,46 +443,61 @@ func (im InputMsg) MustSliceFloat32(sep string) []float32 {
} }
func YesNo(hint string, defaults bool) bool { func YesNo(hint string, defaults bool) bool {
res, err := YesNoE(hint, defaults)
if err != nil {
return defaults
}
return res
}
func parseYesNoValue(raw string, defaults bool) (bool, bool) {
raw = strings.TrimSpace(strings.ToUpper(raw))
if raw == "" {
return defaults, true
}
switch []rune(raw)[0] {
case 'Y':
return true, true
case 'N':
return false, true
default:
return false, false
}
}
func YesNoE(hint string, defaults bool) (bool, error) {
for { for {
res := strings.ToUpper(MessageBox(hint, "").MustString()) res, err := MessageBox(hint, "").String()
if res == "" { if err != nil {
return defaults return false, err
} }
res = res[0:1] if answer, ok := parseYesNoValue(res, defaults); ok {
if res == "Y" { return answer, nil
return true
} else if res == "N" {
return false
} }
} }
} }
func StopUntil(hint string, trigger string, repeat bool) error { func StopUntil(hint string, trigger string, repeat bool) error {
pressLen := len([]rune(trigger)) triggerRunes := []rune(trigger)
pressLen := len(triggerRunes)
if trigger == "" { if trigger == "" {
pressLen = 1 pressLen = 1
} }
fd := int(os.Stdin.Fd()) session, err := newRawTerminalSession(hint, false)
if hint != "" {
fmt.Print(hint)
}
state, err := terminal.MakeRaw(fd)
if err != nil { if err != nil {
return err return err
} }
defer terminal.Restore(fd, state) defer session.Close()
inputReader := bufio.NewReader(os.Stdin)
//ioBuf := make([]byte, pressLen)
i := 0 i := 0
for { for {
b, _, err := inputReader.ReadRune() b, _, err := session.reader.ReadRune()
if err != nil { if err != nil {
return err return err
} }
if trigger == "" { if trigger == "" {
break break
} }
if b == []rune(trigger)[i] { if b == triggerRunes[i] {
i++ i++
if i == pressLen { if i == pressLen {
break break

View File

@ -5,6 +5,41 @@ import (
"testing" "testing"
) )
func TestPromptRedrawHint(t *testing.T) {
got := promptRedrawHint("头部提示\n 中文确认: ")
if got != "中文确认:" {
t.Fatalf("unexpected redraw hint: got %q", got)
}
}
func TestStringDisplayWidth(t *testing.T) {
got := stringDisplayWidth("中a[]")
if got != 5 {
t.Fatalf("unexpected display width: got %d want 5", got)
}
}
func TestParseYesNoValue(t *testing.T) {
cases := []struct {
name string
input string
defaults bool
want bool
ok bool
}{
{name: "default", input: " ", defaults: true, want: true, ok: true},
{name: "yes", input: "yes", defaults: false, want: true, ok: true},
{name: "no", input: "No", defaults: true, want: false, ok: true},
{name: "invalid", input: "maybe", defaults: false, want: false, ok: false},
}
for _, tc := range cases {
got, ok := parseYesNoValue(tc.input, tc.defaults)
if got != tc.want || ok != tc.ok {
t.Fatalf("%s: got (%v, %v) want (%v, %v)", tc.name, got, ok, tc.want, tc.ok)
}
}
}
func Test_Slice(t *testing.T) { func Test_Slice(t *testing.T) {
var data = InputMsg{ var data = InputMsg{
msg: "true,false,true,true,false,0,1,hello", msg: "true,false,true,true,false,0,1,hello",
@ -21,3 +56,29 @@ func Test_Slice(t *testing.T) {
} }
fmt.Println(res) fmt.Println(res)
} }
func TestSliceMsg(t *testing.T) {
var data = InputMsg{
msg: "",
err: nil,
skipSliceSigErr: false,
}
res, err := data.SliceString(",")
if err != nil {
fmt.Println(res)
t.Fatal(err)
}
if len(res) != 0 {
t.Fatal(res)
}
fmt.Println(len(res))
res2, err := data.SliceInt64(",")
if err != nil {
fmt.Println(res2)
t.Fatal(err)
}
if len(res2) != 0 {
t.Fatal(res2)
}
fmt.Println(len(res2))
}

325
que.go Normal file
View File

@ -0,0 +1,325 @@
package stario
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"os"
"sync"
"time"
)
var ErrDeadlineExceeded error = errors.New("deadline exceeded")
// 识别头
var header = []byte{11, 27, 19, 96, 12, 25, 02, 20}
// MsgQueue 为基本的信息单位
type MsgQueue struct {
ID uint16
Msg []byte
Conn interface{}
}
// StarQueue 为流数据中的消息队列分发
type StarQueue struct {
maxLength uint32
count int64
Encode bool
msgID uint16
msgPool chan MsgQueue
unFinMsg sync.Map
lastID int //= -1
ctx context.Context
cancel context.CancelFunc
duration time.Duration
EncodeFunc func([]byte) []byte
DecodeFunc func([]byte) []byte
//restoreMu sync.Mutex
}
func NewQueueCtx(ctx context.Context, count int64, maxMsgLength uint32) *StarQueue {
var q StarQueue
q.Encode = false
q.count = count
q.maxLength = maxMsgLength
q.msgPool = make(chan MsgQueue, count)
if ctx == nil {
q.ctx, q.cancel = context.WithCancel(context.Background())
} else {
q.ctx, q.cancel = context.WithCancel(ctx)
}
q.duration = 0
return &q
}
func NewQueueWithCount(count int64) *StarQueue {
return NewQueueCtx(nil, count, 0)
}
// NewQueue 建立一个新消息队列
func NewQueue() *StarQueue {
return NewQueueWithCount(32)
}
// Uint32ToByte 4位uint32转byte
func Uint32ToByte(src uint32) []byte {
res := make([]byte, 4)
res[3] = uint8(src)
res[2] = uint8(src >> 8)
res[1] = uint8(src >> 16)
res[0] = uint8(src >> 24)
return res
}
// ByteToUint32 byte转4位uint32
func ByteToUint32(src []byte) uint32 {
var res uint32
buffer := bytes.NewBuffer(src)
binary.Read(buffer, binary.BigEndian, &res)
return res
}
// Uint16ToByte 2位uint16转byte
func Uint16ToByte(src uint16) []byte {
res := make([]byte, 2)
res[1] = uint8(src)
res[0] = uint8(src >> 8)
return res
}
// ByteToUint16 用于byte转uint16
func ByteToUint16(src []byte) uint16 {
var res uint16
buffer := bytes.NewBuffer(src)
binary.Read(buffer, binary.BigEndian, &res)
return res
}
// BuildMessage 生成编码后的信息用于发送
func (q *StarQueue) BuildMessage(src []byte) []byte {
var buff bytes.Buffer
q.msgID++
if q.Encode {
src = q.EncodeFunc(src)
}
length := uint32(len(src))
buff.Write(header)
buff.Write(Uint32ToByte(length))
buff.Write(Uint16ToByte(q.msgID))
buff.Write(src)
return buff.Bytes()
}
// BuildHeader 生成编码后的Header用于发送
func (q *StarQueue) BuildHeader(length uint32) []byte {
var buff bytes.Buffer
q.msgID++
buff.Write(header)
buff.Write(Uint32ToByte(length))
buff.Write(Uint16ToByte(q.msgID))
return buff.Bytes()
}
type unFinMsg struct {
ID uint16
LengthRecv uint32
// HeaderMsg 信息头应当为14位8位识别码+4位长度码+2位id
HeaderMsg []byte
RecvMsg []byte
}
func (q *StarQueue) push2list(msg MsgQueue) {
q.msgPool <- msg
}
// ParseMessage 用于解析收到的msg信息
func (q *StarQueue) ParseMessage(msg []byte, conn interface{}) error {
return q.parseMessage(msg, conn)
}
// parseMessage 用于解析收到的msg信息
func (q *StarQueue) parseMessage(msg []byte, conn interface{}) error {
tmp, ok := q.unFinMsg.Load(conn)
if ok { //存在未完成的信息
lastMsg := tmp.(*unFinMsg)
headerLen := len(lastMsg.HeaderMsg)
if headerLen < 14 { //未完成头标题
//传输的数据不能填充header头
if len(msg) < 14-headerLen {
//加入header头并退出
lastMsg.HeaderMsg = bytesMerge(lastMsg.HeaderMsg, msg)
q.unFinMsg.Store(conn, lastMsg)
return nil
}
//获取14字节完整的header
header := msg[0 : 14-headerLen]
lastMsg.HeaderMsg = bytesMerge(lastMsg.HeaderMsg, header)
//检查收到的header是否为认证header
//若不是,丢弃并重新来过
if !checkHeader(lastMsg.HeaderMsg[0:8]) {
q.unFinMsg.Delete(conn)
if len(msg) == 0 {
return nil
}
return q.parseMessage(msg, conn)
}
//获得本数据包长度
lastMsg.LengthRecv = ByteToUint32(lastMsg.HeaderMsg[8:12])
if q.maxLength != 0 && lastMsg.LengthRecv > q.maxLength {
q.unFinMsg.Delete(conn)
return fmt.Errorf("msg length is %d ,too large than %d", lastMsg.LengthRecv, q.maxLength)
}
//获得本数据包ID
lastMsg.ID = ByteToUint16(lastMsg.HeaderMsg[12:14])
//存入列表
q.unFinMsg.Store(conn, lastMsg)
msg = msg[14-headerLen:]
if uint32(len(msg)) < lastMsg.LengthRecv {
lastMsg.RecvMsg = msg
q.unFinMsg.Store(conn, lastMsg)
return nil
}
if uint32(len(msg)) >= lastMsg.LengthRecv {
lastMsg.RecvMsg = msg[0:lastMsg.LengthRecv]
if q.Encode {
lastMsg.RecvMsg = q.DecodeFunc(lastMsg.RecvMsg)
}
msg = msg[lastMsg.LengthRecv:]
storeMsg := MsgQueue{
ID: lastMsg.ID,
Msg: lastMsg.RecvMsg,
Conn: conn,
}
//q.restoreMu.Lock()
q.push2list(storeMsg)
//q.restoreMu.Unlock()
q.unFinMsg.Delete(conn)
return q.parseMessage(msg, conn)
}
} else {
lastID := int(lastMsg.LengthRecv) - len(lastMsg.RecvMsg)
if lastID < 0 {
q.unFinMsg.Delete(conn)
return q.parseMessage(msg, conn)
}
if len(msg) >= lastID {
lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg[0:lastID])
if q.Encode {
lastMsg.RecvMsg = q.DecodeFunc(lastMsg.RecvMsg)
}
storeMsg := MsgQueue{
ID: lastMsg.ID,
Msg: lastMsg.RecvMsg,
Conn: conn,
}
q.push2list(storeMsg)
q.unFinMsg.Delete(conn)
if len(msg) == lastID {
return nil
}
msg = msg[lastID:]
return q.parseMessage(msg, conn)
}
lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg)
q.unFinMsg.Store(conn, lastMsg)
return nil
}
}
if len(msg) == 0 {
return nil
}
var start int
if start = searchHeader(msg); start == -1 {
return errors.New("data format error")
}
msg = msg[start:]
lastMsg := unFinMsg{}
q.unFinMsg.Store(conn, &lastMsg)
return q.parseMessage(msg, conn)
}
func checkHeader(msg []byte) bool {
if len(msg) != 8 {
return false
}
for k, v := range msg {
if v != header[k] {
return false
}
}
return true
}
func searchHeader(msg []byte) int {
if len(msg) < 8 {
return 0
}
for k, v := range msg {
find := 0
if v == header[0] {
for k2, v2 := range header {
if msg[k+k2] == v2 {
find++
} else {
break
}
}
if find == 8 {
return k
}
}
}
return -1
}
func bytesMerge(src ...[]byte) []byte {
var buff bytes.Buffer
for _, v := range src {
buff.Write(v)
}
return buff.Bytes()
}
// Restore 获取收到的信息
func (q *StarQueue) Restore() (MsgQueue, error) {
if q.duration.Seconds() == 0 {
q.duration = 86400 * time.Second
}
for {
select {
case <-q.ctx.Done():
return MsgQueue{}, errors.New("Stoped By External Function Call")
case <-time.After(q.duration):
if q.duration != 0 {
return MsgQueue{}, ErrDeadlineExceeded
}
case data, ok := <-q.msgPool:
if !ok {
return MsgQueue{}, os.ErrClosed
}
return data, nil
}
}
}
// RestoreOne 获取收到的一个信息
// 兼容性修改
func (q *StarQueue) RestoreOne() (MsgQueue, error) {
return q.Restore()
}
// Stop 立即停止Restore
func (q *StarQueue) Stop() {
q.cancel()
}
// RestoreDuration Restore最大超时时间
func (q *StarQueue) RestoreDuration(tm time.Duration) {
q.duration = tm
}
func (q *StarQueue) RestoreChan() <-chan MsgQueue {
return q.msgPool
}

42
que_test.go Normal file
View File

@ -0,0 +1,42 @@
package stario
import (
"fmt"
"testing"
"time"
)
func Test_QueSpeed(t *testing.T) {
que := NewQueueWithCount(0)
stop := make(chan struct{}, 1)
que.RestoreDuration(time.Second * 10)
var count int64
go func() {
for {
select {
case <-stop:
//fmt.Println(count)
return
default:
}
_, err := que.RestoreOne()
if err == nil {
count++
}
}
}()
cp := 0
stoped := time.After(time.Second * 10)
data := que.BuildMessage([]byte("hello"))
for {
select {
case <-stoped:
fmt.Println(count, cp)
stop <- struct{}{}
return
default:
que.ParseMessage(data, "lala")
cp++
}
}
}

147
sync.go
View File

@ -1,55 +1,142 @@
package stario package stario
import ( import (
"fmt"
"sync" "sync"
"sync/atomic" )
"time"
type waitGroupAddMode uint8
const (
waitGroupAddModeStrict waitGroupAddMode = iota
waitGroupAddModeLoose
) )
type WaitGroup struct { type WaitGroup struct {
wg *sync.WaitGroup wg sync.WaitGroup
maxCount uint32 mu sync.Mutex
allCount uint32 cond *sync.Cond
initOnce sync.Once
maxCount int
running int
addMode waitGroupAddMode
} }
func NewWaitGroup(maxCount int) WaitGroup { func NewWaitGroup(maxCount int) WaitGroup {
return WaitGroup{wg: &sync.WaitGroup{}, maxCount: uint32(maxCount)} if maxCount < 0 {
panic("stario: negative max wait count")
}
return WaitGroup{
maxCount: maxCount,
addMode: waitGroupAddModeStrict,
}
} }
func (swg *WaitGroup) Add(delta int) { func (w *WaitGroup) init() {
var Udelta uint32 w.initOnce.Do(func() {
w.cond = sync.NewCond(&w.mu)
})
}
func (w *WaitGroup) Add(delta int) {
w.init()
if delta == 0 {
return
}
if delta < 0 { if delta < 0 {
Udelta = uint32(-delta - 1) w.release(-delta)
} else { return
Udelta = uint32(delta)
} }
for { w.acquire(delta)
allC := atomic.LoadUint32(&swg.allCount) }
if atomic.LoadUint32(&swg.maxCount) == 0 || atomic.LoadUint32(&swg.maxCount) >= allC+uint32(delta) {
if delta < 0 { func (w *WaitGroup) acquire(delta int) {
atomic.AddUint32(&swg.allCount, ^uint32(Udelta)) w.mu.Lock()
} else { defer w.mu.Unlock()
atomic.AddUint32(&swg.allCount, uint32(Udelta)) if w.maxCount <= 0 {
} w.running += delta
break w.wg.Add(delta)
return
}
if delta == 1 {
w.wg.Add(1)
for w.maxCount > 0 && w.running >= w.maxCount {
w.cond.Wait()
} }
time.Sleep(time.Microsecond) w.running++
return
} }
swg.wg.Add(delta) if w.running+delta > w.maxCount {
if w.addMode == waitGroupAddModeStrict {
panic(fmt.Sprintf("stario: WaitGroup.Add(%d) exceeds max limit %d with %d running", delta, w.maxCount, w.running))
}
w.maxCount = w.running + delta
}
w.running += delta
w.wg.Add(delta)
} }
func (swg *WaitGroup) Done() { func (w *WaitGroup) release(delta int) {
swg.Add(-1) w.mu.Lock()
defer w.mu.Unlock()
if delta > w.running {
panic(fmt.Sprintf("stario: WaitGroup.Done releases %d tasks but only %d running", delta, w.running))
}
w.wg.Add(-delta)
w.running -= delta
w.cond.Broadcast()
} }
func (swg *WaitGroup) Wait() { func (w *WaitGroup) Done() {
swg.wg.Wait() w.Add(-1)
} }
func (swg *WaitGroup) GetMaxWaitNum() int { func (w *WaitGroup) Go(fn func()) {
return int(atomic.LoadUint32(&swg.maxCount)) w.Add(1)
go func() {
defer w.Done()
fn()
}()
} }
func (swg *WaitGroup) SetMaxWaitNum(num int) { func (w *WaitGroup) Wait() {
atomic.AddUint32(&swg.maxCount, uint32(num)) w.init()
w.wg.Wait()
}
func (w *WaitGroup) GetMaxWaitNum() int {
w.init()
w.mu.Lock()
defer w.mu.Unlock()
return w.maxCount
}
func (w *WaitGroup) SetMaxWaitNum(num int) {
if num < 0 {
panic("stario: negative max wait count")
}
w.init()
w.mu.Lock()
w.maxCount = num
w.mu.Unlock()
w.cond.Broadcast()
}
func (w *WaitGroup) SetStrictAddMode(strict bool) {
w.init()
w.mu.Lock()
if strict {
w.addMode = waitGroupAddModeStrict
} else {
w.addMode = waitGroupAddModeLoose
}
w.mu.Unlock()
w.cond.Broadcast()
}
func (w *WaitGroup) StrictAddMode() bool {
w.init()
w.mu.Lock()
defer w.mu.Unlock()
return w.addMode == waitGroupAddModeStrict
} }

171
sync_test.go Normal file
View File

@ -0,0 +1,171 @@
package stario
import (
"testing"
"time"
)
func TestWaitGroupAddBlocksAtLimit(t *testing.T) {
wg := NewWaitGroup(1)
wg.Add(1)
unblocked := make(chan struct{})
go func() {
wg.Add(1)
close(unblocked)
wg.Done()
}()
select {
case <-unblocked:
t.Fatal("Add(1) should block while running count is at the limit")
case <-time.After(30 * time.Millisecond):
}
wg.Done()
select {
case <-unblocked:
case <-time.After(200 * time.Millisecond):
t.Fatal("Add(1) did not resume after capacity was released")
}
wg.Wait()
}
func TestWaitGroupStrictAddPanicsWhenBatchExceedsLimit(t *testing.T) {
wg := NewWaitGroup(1)
wg.Add(1)
defer wg.Done()
defer func() {
if recover() == nil {
t.Fatal("expected Add(2) to panic in strict mode when it exceeds the limit")
}
}()
wg.Add(2)
}
func TestWaitGroupLooseAddExpandsLimit(t *testing.T) {
wg := NewWaitGroup(1)
wg.SetStrictAddMode(false)
wg.Add(1)
wg.Add(2)
if got := wg.GetMaxWaitNum(); got != 3 {
t.Fatalf("unexpected max count: got %d want 3", got)
}
for i := 0; i < 3; i++ {
wg.Done()
}
wg.Wait()
}
func TestWaitGroupSetMaxWaitNumOverridesValue(t *testing.T) {
wg := NewWaitGroup(1)
wg.SetMaxWaitNum(4)
if got := wg.GetMaxWaitNum(); got != 4 {
t.Fatalf("unexpected max count: got %d want 4", got)
}
}
func TestWaitGroupSetMaxWaitNumZeroUnblocksWaitingAdd(t *testing.T) {
wg := NewWaitGroup(1)
wg.Add(1)
unblocked := make(chan struct{})
go func() {
wg.Add(1)
close(unblocked)
wg.Done()
}()
select {
case <-unblocked:
t.Fatal("Add(1) should still be blocked before removing the limit")
case <-time.After(30 * time.Millisecond):
}
wg.SetMaxWaitNum(0)
select {
case <-unblocked:
case <-time.After(200 * time.Millisecond):
t.Fatal("Add(1) did not resume after the limit was removed")
}
wg.Done()
wg.Wait()
}
func TestWaitGroupGoWaitsForCompletion(t *testing.T) {
wg := NewWaitGroup(1)
done := make(chan struct{})
wg.Go(func() {
close(done)
})
select {
case <-done:
case <-time.After(200 * time.Millisecond):
t.Fatal("Go(fn) did not run the function")
}
finished := make(chan struct{})
go func() {
wg.Wait()
close(finished)
}()
select {
case <-finished:
case <-time.After(200 * time.Millisecond):
t.Fatal("Wait did not observe Go(fn) completion")
}
}
func TestWaitGroupGoRespectsLimit(t *testing.T) {
wg := NewWaitGroup(1)
releaseFirst := make(chan struct{})
firstStarted := make(chan struct{})
secondStarted := make(chan struct{})
secondReturned := make(chan struct{})
wg.Go(func() {
close(firstStarted)
<-releaseFirst
})
select {
case <-firstStarted:
case <-time.After(200 * time.Millisecond):
t.Fatal("first Go(fn) did not start")
}
go func() {
wg.Go(func() {
close(secondStarted)
})
close(secondReturned)
}()
select {
case <-secondReturned:
t.Fatal("second Go(fn) should block while the limit is full")
case <-time.After(30 * time.Millisecond):
}
close(releaseFirst)
select {
case <-secondStarted:
case <-time.After(200 * time.Millisecond):
t.Fatal("second Go(fn) did not start after capacity was released")
}
wg.Wait()
}