fix:修复一些边界性问题
This commit is contained in:
parent
3c83d1d39f
commit
3add9183b3
7
.gitignore
vendored
Normal file
7
.gitignore
vendored
Normal file
@ -0,0 +1,7 @@
|
||||
.sentrux/
|
||||
agent_readme.md
|
||||
target.md
|
||||
.gocache
|
||||
.gocache/
|
||||
.tmp_*/
|
||||
.idea
|
||||
234
circle.go
234
circle.go
@ -2,151 +2,191 @@ package stario
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ErrStarBufferInvalidCapacity = errors.New("star buffer capacity must be greater than zero")
|
||||
var ErrStarBufferClosed = errors.New("star buffer closed")
|
||||
var ErrStarBufferWriteClosed = errors.New("star buffer write closed")
|
||||
|
||||
type StarBuffer struct {
|
||||
io.Reader
|
||||
io.Writer
|
||||
io.Closer
|
||||
datas []byte
|
||||
pStart uint64
|
||||
pEnd uint64
|
||||
cap uint64
|
||||
isClose atomic.Value
|
||||
isEnd atomic.Value
|
||||
rmu sync.Mutex
|
||||
wmu sync.Mutex
|
||||
datas []byte
|
||||
pStart uint64
|
||||
pEnd uint64
|
||||
size uint64
|
||||
cap uint64
|
||||
isClose bool
|
||||
isWriteEnd bool
|
||||
mu sync.Mutex
|
||||
notEmpty *sync.Cond
|
||||
notFull *sync.Cond
|
||||
}
|
||||
|
||||
func NewStarBuffer(cap uint64) *StarBuffer {
|
||||
rtnBuffer := new(StarBuffer)
|
||||
rtnBuffer.cap = cap
|
||||
rtnBuffer.datas = make([]byte, cap)
|
||||
rtnBuffer.isClose.Store(false)
|
||||
rtnBuffer.isEnd.Store(false)
|
||||
return rtnBuffer
|
||||
func NewStarBuffer(cap uint64) (*StarBuffer, error) {
|
||||
if cap == 0 {
|
||||
return nil, ErrStarBufferInvalidCapacity
|
||||
}
|
||||
rtnBuffer := &StarBuffer{
|
||||
cap: cap,
|
||||
datas: make([]byte, cap),
|
||||
}
|
||||
rtnBuffer.notEmpty = sync.NewCond(&rtnBuffer.mu)
|
||||
rtnBuffer.notFull = sync.NewCond(&rtnBuffer.mu)
|
||||
return rtnBuffer, nil
|
||||
}
|
||||
|
||||
func (star *StarBuffer) Free() uint64 {
|
||||
return star.cap - star.Len()
|
||||
star.mu.Lock()
|
||||
defer star.mu.Unlock()
|
||||
return star.cap - star.size
|
||||
}
|
||||
|
||||
func (star *StarBuffer) Cap() uint64 {
|
||||
star.mu.Lock()
|
||||
defer star.mu.Unlock()
|
||||
return star.cap
|
||||
}
|
||||
|
||||
func (star *StarBuffer) Len() uint64 {
|
||||
if star.pEnd >= star.pStart {
|
||||
return star.pEnd - star.pStart
|
||||
}
|
||||
return star.pEnd - star.pStart + star.cap
|
||||
star.mu.Lock()
|
||||
defer star.mu.Unlock()
|
||||
return star.size
|
||||
}
|
||||
|
||||
func (star *StarBuffer) getByte() (byte, error) {
|
||||
if star.isClose.Load().(bool) || (star.Len() == 0 && star.isEnd.Load().(bool)) {
|
||||
star.mu.Lock()
|
||||
defer star.mu.Unlock()
|
||||
for star.size == 0 && !star.isWriteEnd && !star.isClose {
|
||||
star.notEmpty.Wait()
|
||||
}
|
||||
if star.size == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
if star.Len() == 0 {
|
||||
return 0, os.ErrNotExist
|
||||
}
|
||||
nowPtr := star.pStart
|
||||
nextPtr := star.pStart + 1
|
||||
if nextPtr >= star.cap {
|
||||
nextPtr = 0
|
||||
}
|
||||
data := star.datas[nowPtr]
|
||||
ok := atomic.CompareAndSwapUint64(&star.pStart, nowPtr, nextPtr)
|
||||
if !ok {
|
||||
return 0, os.ErrInvalid
|
||||
}
|
||||
data := star.datas[star.pStart]
|
||||
star.pStart = (star.pStart + 1) % star.cap
|
||||
star.size--
|
||||
star.notFull.Broadcast()
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (star *StarBuffer) putByte(data byte) error {
|
||||
if star.isClose.Load().(bool) {
|
||||
return io.EOF
|
||||
star.mu.Lock()
|
||||
defer star.mu.Unlock()
|
||||
for star.size == star.cap && !star.isClose && !star.isWriteEnd {
|
||||
star.notFull.Wait()
|
||||
}
|
||||
nowPtr := star.pEnd
|
||||
kariEnd := nowPtr + 1
|
||||
if kariEnd == star.cap {
|
||||
kariEnd = 0
|
||||
if star.isClose {
|
||||
return ErrStarBufferClosed
|
||||
}
|
||||
if kariEnd == atomic.LoadUint64(&star.pStart) {
|
||||
for {
|
||||
time.Sleep(time.Microsecond)
|
||||
runtime.Gosched()
|
||||
if kariEnd != atomic.LoadUint64(&star.pStart) {
|
||||
break
|
||||
}
|
||||
}
|
||||
if star.isWriteEnd {
|
||||
return ErrStarBufferWriteClosed
|
||||
}
|
||||
star.datas[nowPtr] = data
|
||||
if ok := atomic.CompareAndSwapUint64(&star.pEnd, nowPtr, kariEnd); !ok {
|
||||
return os.ErrInvalid
|
||||
star.datas[star.pEnd] = data
|
||||
star.pEnd = (star.pEnd + 1) % star.cap
|
||||
star.size++
|
||||
star.notEmpty.Broadcast()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (star *StarBuffer) EndWrite() error {
|
||||
star.mu.Lock()
|
||||
defer star.mu.Unlock()
|
||||
if star.isClose {
|
||||
return ErrStarBufferClosed
|
||||
}
|
||||
star.isWriteEnd = true
|
||||
star.notEmpty.Broadcast()
|
||||
star.notFull.Broadcast()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (star *StarBuffer) Close() error {
|
||||
star.isClose.Store(true)
|
||||
star.mu.Lock()
|
||||
defer star.mu.Unlock()
|
||||
star.isClose = true
|
||||
star.isWriteEnd = true
|
||||
star.notEmpty.Broadcast()
|
||||
star.notFull.Broadcast()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (star *StarBuffer) Read(buf []byte) (int, error) {
|
||||
if star.isClose.Load().(bool) || (star.Len() == 0 && star.isEnd.Load().(bool)) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
if buf == nil {
|
||||
return 0, errors.New("buffer is nil")
|
||||
}
|
||||
star.rmu.Lock()
|
||||
defer star.rmu.Unlock()
|
||||
var sum int = 0
|
||||
for i := 0; i < len(buf); i++ {
|
||||
data, err := star.getByte()
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
if sum == 0 {
|
||||
return sum, err
|
||||
}
|
||||
return sum, nil
|
||||
}
|
||||
if err == os.ErrNotExist {
|
||||
i--
|
||||
continue
|
||||
}
|
||||
return sum, nil
|
||||
}
|
||||
buf[i] = data
|
||||
sum++
|
||||
if len(buf) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
star.mu.Lock()
|
||||
defer star.mu.Unlock()
|
||||
for star.size == 0 && !star.isWriteEnd && !star.isClose {
|
||||
star.notEmpty.Wait()
|
||||
}
|
||||
if star.size == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
sum := minInt(len(buf), int(star.size))
|
||||
first := minInt(sum, int(star.cap-star.pStart))
|
||||
copy(buf, star.datas[star.pStart:star.pStart+uint64(first)])
|
||||
second := sum - first
|
||||
if second > 0 {
|
||||
copy(buf[first:], star.datas[:second])
|
||||
}
|
||||
star.pStart = (star.pStart + uint64(sum)) % star.cap
|
||||
star.size -= uint64(sum)
|
||||
star.notFull.Broadcast()
|
||||
return sum, nil
|
||||
}
|
||||
|
||||
func (star *StarBuffer) Write(bts []byte) (int, error) {
|
||||
if bts == nil && !star.isEnd.Load().(bool) {
|
||||
star.isEnd.Store(true)
|
||||
if bts == nil {
|
||||
return 0, star.EndWrite()
|
||||
}
|
||||
if len(bts) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if bts == nil || star.isClose.Load().(bool) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
star.wmu.Lock()
|
||||
defer star.wmu.Unlock()
|
||||
var sum = 0
|
||||
for i := 0; i < len(bts); i++ {
|
||||
err := star.putByte(bts[i])
|
||||
if err != nil {
|
||||
fmt.Println("Write bts err:", err)
|
||||
return sum, err
|
||||
star.mu.Lock()
|
||||
defer star.mu.Unlock()
|
||||
sum := 0
|
||||
for sum < len(bts) {
|
||||
for star.size == star.cap && !star.isClose && !star.isWriteEnd {
|
||||
star.notFull.Wait()
|
||||
}
|
||||
sum++
|
||||
if star.isClose {
|
||||
if sum == 0 {
|
||||
return 0, ErrStarBufferClosed
|
||||
}
|
||||
return sum, ErrStarBufferClosed
|
||||
}
|
||||
if star.isWriteEnd {
|
||||
if sum == 0 {
|
||||
return 0, ErrStarBufferWriteClosed
|
||||
}
|
||||
return sum, ErrStarBufferWriteClosed
|
||||
}
|
||||
space := int(star.cap - star.size)
|
||||
if space == 0 {
|
||||
continue
|
||||
}
|
||||
n := minInt(len(bts)-sum, space)
|
||||
first := minInt(n, int(star.cap-star.pEnd))
|
||||
copy(star.datas[star.pEnd:star.pEnd+uint64(first)], bts[sum:sum+first])
|
||||
second := n - first
|
||||
if second > 0 {
|
||||
copy(star.datas[:second], bts[sum+first:sum+n])
|
||||
}
|
||||
star.pEnd = (star.pEnd + uint64(n)) % star.cap
|
||||
star.size += uint64(n)
|
||||
sum += n
|
||||
star.notEmpty.Broadcast()
|
||||
}
|
||||
return sum, nil
|
||||
}
|
||||
|
||||
func minInt(a int, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
@ -1,14 +1,91 @@
|
||||
package stario
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewStarBufferRejectsZeroCapacity(t *testing.T) {
|
||||
buf, err := NewStarBuffer(0)
|
||||
if err != ErrStarBufferInvalidCapacity {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if buf != nil {
|
||||
t.Fatal("expected nil buffer when capacity is invalid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStarBufferEndWriteDrainsThenEOF(t *testing.T) {
|
||||
buf, err := NewStarBuffer(4)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := buf.Write([]byte("abcd")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := buf.EndWrite(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got := make([]byte, 4)
|
||||
n, err := buf.Read(got)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != 4 || !bytes.Equal(got[:n], []byte("abcd")) {
|
||||
t.Fatalf("unexpected payload: n=%d data=%q", n, got[:n])
|
||||
}
|
||||
|
||||
n, err = buf.Read(got)
|
||||
if n != 0 || err != io.EOF {
|
||||
t.Fatalf("expected EOF after draining buffer, got n=%d err=%v", n, err)
|
||||
}
|
||||
|
||||
if _, err := buf.Write([]byte("x")); err != ErrStarBufferWriteClosed {
|
||||
t.Fatalf("unexpected write error after EndWrite: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStarBufferCloseAllowsDrain(t *testing.T) {
|
||||
buf, err := NewStarBuffer(4)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := buf.Write([]byte("ab")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := buf.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got := make([]byte, 2)
|
||||
n, err := buf.Read(got)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != 2 || !bytes.Equal(got[:n], []byte("ab")) {
|
||||
t.Fatalf("unexpected payload after close: n=%d data=%q", n, got[:n])
|
||||
}
|
||||
|
||||
n, err = buf.Read(got)
|
||||
if n != 0 || err != io.EOF {
|
||||
t.Fatalf("expected EOF after draining closed buffer, got n=%d err=%v", n, err)
|
||||
}
|
||||
|
||||
if _, err := buf.Write([]byte("x")); err != ErrStarBufferClosed {
|
||||
t.Fatalf("unexpected write error after Close: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Circle(t *testing.T) {
|
||||
buf := NewStarBuffer(2048)
|
||||
buf, err := NewStarBuffer(2048)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
//fmt.Println("write start")
|
||||
@ -37,7 +114,10 @@ func Test_Circle(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_Circle_Speed(t *testing.T) {
|
||||
buf := NewStarBuffer(1048976)
|
||||
buf, err := NewStarBuffer(1048976)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
count := uint64(0)
|
||||
for i := 1; i <= 10; i++ {
|
||||
go func() {
|
||||
@ -61,7 +141,10 @@ func Test_Circle_Speed(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_Circle_Speed2(t *testing.T) {
|
||||
buf := NewStarBuffer(8192)
|
||||
buf, err := NewStarBuffer(8192)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
count := uint64(0)
|
||||
for i := 1; i <= 10; i++ {
|
||||
go func() {
|
||||
|
||||
312
io.go
312
io.go
@ -16,12 +16,20 @@ type InputMsg struct {
|
||||
skipSliceSigErr bool
|
||||
}
|
||||
|
||||
type rawTerminalSession struct {
|
||||
fd int
|
||||
state *terminal.State
|
||||
reader *bufio.Reader
|
||||
redrawHint string
|
||||
printNewline bool
|
||||
}
|
||||
|
||||
func Passwd(hint string, defaultVal string) InputMsg {
|
||||
return passwd(hint, defaultVal, "", false)
|
||||
return passwd(hint, defaultVal, "", true)
|
||||
}
|
||||
|
||||
func PasswdWithMask(hint string, defaultVal string, mask string) InputMsg {
|
||||
return passwd(hint, defaultVal, mask, false)
|
||||
return passwd(hint, defaultVal, mask, true)
|
||||
}
|
||||
|
||||
func PasswdResponseSignal(hint string, defaultVal string) InputMsg {
|
||||
@ -36,54 +44,162 @@ func MessageBoxRaw(hint string, defaultVal string) InputMsg {
|
||||
return messageBox(hint, defaultVal)
|
||||
}
|
||||
|
||||
func messageBox(hint string, defaultVal string) InputMsg {
|
||||
var ioBuf []rune
|
||||
func newRawTerminalSession(hint string, printNewline bool) (*rawTerminalSession, error) {
|
||||
if hint != "" {
|
||||
fmt.Print(hint)
|
||||
}
|
||||
if strings.Index(hint, "\n") >= 0 {
|
||||
hint = strings.TrimSpace(hint[strings.LastIndex(hint, "\n"):])
|
||||
}
|
||||
fd := int(os.Stdin.Fd())
|
||||
state, err := terminal.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return InputMsg{msg: "", err: err}
|
||||
return nil, err
|
||||
}
|
||||
defer fmt.Println()
|
||||
defer terminal.Restore(fd, state)
|
||||
inputReader := bufio.NewReader(os.Stdin)
|
||||
for {
|
||||
b, _, err := inputReader.ReadRune()
|
||||
if err != nil {
|
||||
return InputMsg{msg: "", err: err}
|
||||
}
|
||||
if b == 0x0d {
|
||||
strValue := strings.TrimSpace(string(ioBuf))
|
||||
if len(strValue) == 0 {
|
||||
strValue = defaultVal
|
||||
}
|
||||
return InputMsg{msg: strValue, err: err}
|
||||
}
|
||||
if b == 0x08 || b == 0x7F {
|
||||
if len(ioBuf) > 0 {
|
||||
ioBuf = ioBuf[:len(ioBuf)-1]
|
||||
}
|
||||
fmt.Print("\r")
|
||||
for i := 0; i < len(ioBuf)+2+len(hint); i++ {
|
||||
fmt.Print(" ")
|
||||
}
|
||||
} else {
|
||||
ioBuf = append(ioBuf, b)
|
||||
}
|
||||
fmt.Print("\r")
|
||||
if hint != "" {
|
||||
fmt.Print(hint)
|
||||
}
|
||||
fmt.Print(string(ioBuf))
|
||||
return &rawTerminalSession{
|
||||
fd: fd,
|
||||
state: state,
|
||||
reader: bufio.NewReader(os.Stdin),
|
||||
redrawHint: promptRedrawHint(hint),
|
||||
printNewline: printNewline,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (session *rawTerminalSession) Close() {
|
||||
if session == nil || session.state == nil {
|
||||
return
|
||||
}
|
||||
_ = terminal.Restore(session.fd, session.state)
|
||||
if session.printNewline {
|
||||
fmt.Println()
|
||||
}
|
||||
}
|
||||
|
||||
func isSiganl(s rune) bool {
|
||||
func (session *rawTerminalSession) Restore() error {
|
||||
if session == nil || session.state == nil {
|
||||
return nil
|
||||
}
|
||||
return terminal.Restore(session.fd, session.state)
|
||||
}
|
||||
|
||||
func promptRedrawHint(hint string) string {
|
||||
if strings.Index(hint, "\n") >= 0 {
|
||||
hint = hint[strings.LastIndex(hint, "\n")+1:]
|
||||
}
|
||||
return strings.TrimSpace(hint)
|
||||
}
|
||||
|
||||
func finalizeInputValue(raw string, defaultVal string) string {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if len(raw) == 0 {
|
||||
return defaultVal
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
func renderRawEcho(ioBuf []rune, mask string) string {
|
||||
if mask == "" {
|
||||
return string(ioBuf)
|
||||
}
|
||||
return strings.Repeat(mask, len(ioBuf))
|
||||
}
|
||||
|
||||
func redrawPromptLine(hint string, echo string, lastWidth int) int {
|
||||
nowWidth := stringDisplayWidth(hint) + stringDisplayWidth(echo)
|
||||
clearWidth := lastWidth
|
||||
if nowWidth > clearWidth {
|
||||
clearWidth = nowWidth
|
||||
}
|
||||
fmt.Print("\r")
|
||||
if clearWidth > 0 {
|
||||
fmt.Print(strings.Repeat(" ", clearWidth))
|
||||
fmt.Print("\r")
|
||||
}
|
||||
if hint != "" {
|
||||
fmt.Print(hint)
|
||||
}
|
||||
if echo != "" {
|
||||
fmt.Print(echo)
|
||||
}
|
||||
return nowWidth
|
||||
}
|
||||
|
||||
func stringDisplayWidth(text string) int {
|
||||
width := 0
|
||||
for _, r := range text {
|
||||
width += runeDisplayWidth(r)
|
||||
}
|
||||
return width
|
||||
}
|
||||
|
||||
func runeDisplayWidth(r rune) int {
|
||||
switch {
|
||||
case r == '\t':
|
||||
return 4
|
||||
case r < 0x20 || (r >= 0x7f && r < 0xa0):
|
||||
return 0
|
||||
case isWideRune(r):
|
||||
return 2
|
||||
default:
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
func isWideRune(r rune) bool {
|
||||
return r >= 0x1100 && (r <= 0x115f ||
|
||||
r == 0x2329 || r == 0x232a ||
|
||||
(r >= 0x2e80 && r <= 0xa4cf && r != 0x303f) ||
|
||||
(r >= 0xac00 && r <= 0xd7a3) ||
|
||||
(r >= 0xf900 && r <= 0xfaff) ||
|
||||
(r >= 0xfe10 && r <= 0xfe19) ||
|
||||
(r >= 0xfe30 && r <= 0xfe6f) ||
|
||||
(r >= 0xff00 && r <= 0xff60) ||
|
||||
(r >= 0xffe0 && r <= 0xffe6) ||
|
||||
(r >= 0x1f300 && r <= 0x1f64f) ||
|
||||
(r >= 0x1f900 && r <= 0x1f9ff) ||
|
||||
(r >= 0x20000 && r <= 0x3fffd))
|
||||
}
|
||||
|
||||
func rawLineInput(hint string, defaultVal string, mask string, handleSignal bool) InputMsg {
|
||||
session, err := newRawTerminalSession(hint, true)
|
||||
if err != nil {
|
||||
return InputMsg{msg: "", err: err}
|
||||
}
|
||||
defer session.Close()
|
||||
ioBuf := make([]rune, 0, 16)
|
||||
lastWidth := 0
|
||||
for {
|
||||
b, _, err := session.reader.ReadRune()
|
||||
if err != nil {
|
||||
return InputMsg{msg: "", err: err}
|
||||
}
|
||||
if handleSignal && isSignal(b) {
|
||||
if runtime.GOOS != "windows" {
|
||||
if err := session.Restore(); err != nil {
|
||||
return InputMsg{msg: "", err: err}
|
||||
}
|
||||
}
|
||||
if err := signal(b); err != nil {
|
||||
return InputMsg{msg: "", err: err}
|
||||
}
|
||||
continue
|
||||
}
|
||||
switch b {
|
||||
case 0x0d, 0x0a:
|
||||
return InputMsg{msg: finalizeInputValue(string(ioBuf), defaultVal), err: nil}
|
||||
case 0x08, 0x7F:
|
||||
if len(ioBuf) > 0 {
|
||||
ioBuf = ioBuf[:len(ioBuf)-1]
|
||||
}
|
||||
default:
|
||||
ioBuf = append(ioBuf, b)
|
||||
}
|
||||
lastWidth = redrawPromptLine(session.redrawHint, renderRawEcho(ioBuf, mask), lastWidth)
|
||||
}
|
||||
}
|
||||
|
||||
func messageBox(hint string, defaultVal string) InputMsg {
|
||||
return rawLineInput(hint, defaultVal, "", false)
|
||||
}
|
||||
|
||||
func isSignal(s rune) bool {
|
||||
switch s {
|
||||
case 0x03, 0x1a, 0x1c:
|
||||
return true
|
||||
@ -93,63 +209,7 @@ func isSiganl(s rune) bool {
|
||||
}
|
||||
|
||||
func passwd(hint string, defaultVal string, mask string, handleSignal bool) InputMsg {
|
||||
var ioBuf []rune
|
||||
if hint != "" {
|
||||
fmt.Print(hint)
|
||||
}
|
||||
if strings.Index(hint, "\n") >= 0 {
|
||||
hint = strings.TrimSpace(hint[strings.LastIndex(hint, "\n"):])
|
||||
}
|
||||
fd := int(os.Stdin.Fd())
|
||||
state, err := terminal.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return InputMsg{msg: "", err: err}
|
||||
}
|
||||
defer fmt.Println()
|
||||
defer terminal.Restore(fd, state)
|
||||
inputReader := bufio.NewReader(os.Stdin)
|
||||
for {
|
||||
b, _, err := inputReader.ReadRune()
|
||||
if err != nil {
|
||||
return InputMsg{msg: "", err: err}
|
||||
}
|
||||
if handleSignal && isSiganl(b) {
|
||||
if runtime.GOOS != "windows" {
|
||||
terminal.Restore(fd, state)
|
||||
}
|
||||
if err := signal(b); err != nil {
|
||||
return InputMsg{
|
||||
msg: "",
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
}
|
||||
if b == 0x0d {
|
||||
strValue := strings.TrimSpace(string(ioBuf))
|
||||
if len(strValue) == 0 {
|
||||
strValue = defaultVal
|
||||
}
|
||||
return InputMsg{msg: strValue, err: err}
|
||||
}
|
||||
if b == 0x08 || b == 0x7F {
|
||||
if len(ioBuf) > 0 {
|
||||
ioBuf = ioBuf[:len(ioBuf)-1]
|
||||
}
|
||||
fmt.Print("\r")
|
||||
for i := 0; i < len(ioBuf)+2+len(hint); i++ {
|
||||
fmt.Print(" ")
|
||||
}
|
||||
} else {
|
||||
ioBuf = append(ioBuf, b)
|
||||
}
|
||||
fmt.Print("\r")
|
||||
if hint != "" {
|
||||
fmt.Print(hint)
|
||||
}
|
||||
for i := 0; i < len(ioBuf); i++ {
|
||||
fmt.Print(mask)
|
||||
}
|
||||
}
|
||||
return rawLineInput(hint, defaultVal, mask, handleSignal)
|
||||
}
|
||||
|
||||
func MessageBox(hint string, defaultVal string) InputMsg {
|
||||
@ -161,10 +221,7 @@ func MessageBox(hint string, defaultVal string) InputMsg {
|
||||
if err != nil {
|
||||
return InputMsg{msg: str, err: err}
|
||||
}
|
||||
str = strings.TrimSpace(str)
|
||||
if len(str) == 0 {
|
||||
str = defaultVal
|
||||
}
|
||||
str = finalizeInputValue(str, defaultVal)
|
||||
return InputMsg{msg: str, err: err}
|
||||
}
|
||||
|
||||
@ -386,46 +443,61 @@ func (im InputMsg) MustSliceFloat32(sep string) []float32 {
|
||||
}
|
||||
|
||||
func YesNo(hint string, defaults bool) bool {
|
||||
res, err := YesNoE(hint, defaults)
|
||||
if err != nil {
|
||||
return defaults
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func parseYesNoValue(raw string, defaults bool) (bool, bool) {
|
||||
raw = strings.TrimSpace(strings.ToUpper(raw))
|
||||
if raw == "" {
|
||||
return defaults, true
|
||||
}
|
||||
switch []rune(raw)[0] {
|
||||
case 'Y':
|
||||
return true, true
|
||||
case 'N':
|
||||
return false, true
|
||||
default:
|
||||
return false, false
|
||||
}
|
||||
}
|
||||
|
||||
func YesNoE(hint string, defaults bool) (bool, error) {
|
||||
for {
|
||||
res := strings.ToUpper(MessageBox(hint, "").MustString())
|
||||
if res == "" {
|
||||
return defaults
|
||||
res, err := MessageBox(hint, "").String()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
res = res[0:1]
|
||||
if res == "Y" {
|
||||
return true
|
||||
} else if res == "N" {
|
||||
return false
|
||||
if answer, ok := parseYesNoValue(res, defaults); ok {
|
||||
return answer, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func StopUntil(hint string, trigger string, repeat bool) error {
|
||||
pressLen := len([]rune(trigger))
|
||||
triggerRunes := []rune(trigger)
|
||||
pressLen := len(triggerRunes)
|
||||
if trigger == "" {
|
||||
pressLen = 1
|
||||
}
|
||||
fd := int(os.Stdin.Fd())
|
||||
if hint != "" {
|
||||
fmt.Print(hint)
|
||||
}
|
||||
state, err := terminal.MakeRaw(fd)
|
||||
session, err := newRawTerminalSession(hint, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer terminal.Restore(fd, state)
|
||||
inputReader := bufio.NewReader(os.Stdin)
|
||||
//ioBuf := make([]byte, pressLen)
|
||||
defer session.Close()
|
||||
i := 0
|
||||
for {
|
||||
b, _, err := inputReader.ReadRune()
|
||||
b, _, err := session.reader.ReadRune()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if trigger == "" {
|
||||
break
|
||||
}
|
||||
if b == []rune(trigger)[i] {
|
||||
if b == triggerRunes[i] {
|
||||
i++
|
||||
if i == pressLen {
|
||||
break
|
||||
|
||||
35
io_test.go
35
io_test.go
@ -5,6 +5,41 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPromptRedrawHint(t *testing.T) {
|
||||
got := promptRedrawHint("头部提示\n 中文确认: ")
|
||||
if got != "中文确认:" {
|
||||
t.Fatalf("unexpected redraw hint: got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringDisplayWidth(t *testing.T) {
|
||||
got := stringDisplayWidth("中a[]")
|
||||
if got != 5 {
|
||||
t.Fatalf("unexpected display width: got %d want 5", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseYesNoValue(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
input string
|
||||
defaults bool
|
||||
want bool
|
||||
ok bool
|
||||
}{
|
||||
{name: "default", input: " ", defaults: true, want: true, ok: true},
|
||||
{name: "yes", input: "yes", defaults: false, want: true, ok: true},
|
||||
{name: "no", input: "No", defaults: true, want: false, ok: true},
|
||||
{name: "invalid", input: "maybe", defaults: false, want: false, ok: false},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
got, ok := parseYesNoValue(tc.input, tc.defaults)
|
||||
if got != tc.want || ok != tc.ok {
|
||||
t.Fatalf("%s: got (%v, %v) want (%v, %v)", tc.name, got, ok, tc.want, tc.ok)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Slice(t *testing.T) {
|
||||
var data = InputMsg{
|
||||
msg: "true,false,true,true,false,0,1,hello",
|
||||
|
||||
131
sync.go
131
sync.go
@ -1,55 +1,142 @@
|
||||
package stario
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type waitGroupAddMode uint8
|
||||
|
||||
const (
|
||||
waitGroupAddModeStrict waitGroupAddMode = iota
|
||||
waitGroupAddModeLoose
|
||||
)
|
||||
|
||||
type WaitGroup struct {
|
||||
wg *sync.WaitGroup
|
||||
maxCount uint32
|
||||
allCount uint32
|
||||
wg sync.WaitGroup
|
||||
mu sync.Mutex
|
||||
cond *sync.Cond
|
||||
initOnce sync.Once
|
||||
maxCount int
|
||||
running int
|
||||
addMode waitGroupAddMode
|
||||
}
|
||||
|
||||
func NewWaitGroup(maxCount int) WaitGroup {
|
||||
return WaitGroup{wg: &sync.WaitGroup{}, maxCount: uint32(maxCount)}
|
||||
if maxCount < 0 {
|
||||
panic("stario: negative max wait count")
|
||||
}
|
||||
return WaitGroup{
|
||||
maxCount: maxCount,
|
||||
addMode: waitGroupAddModeStrict,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WaitGroup) init() {
|
||||
w.initOnce.Do(func() {
|
||||
w.cond = sync.NewCond(&w.mu)
|
||||
})
|
||||
}
|
||||
|
||||
func (w *WaitGroup) Add(delta int) {
|
||||
var Udelta uint32
|
||||
w.init()
|
||||
if delta == 0 {
|
||||
return
|
||||
}
|
||||
if delta < 0 {
|
||||
Udelta = uint32(-delta - 1)
|
||||
} else {
|
||||
Udelta = uint32(delta)
|
||||
w.release(-delta)
|
||||
return
|
||||
}
|
||||
for {
|
||||
allC := atomic.LoadUint32(&w.allCount)
|
||||
if atomic.LoadUint32(&w.maxCount) == 0 || atomic.LoadUint32(&w.maxCount) >= allC+uint32(delta) {
|
||||
if delta < 0 {
|
||||
atomic.AddUint32(&w.allCount, ^uint32(Udelta))
|
||||
} else {
|
||||
atomic.AddUint32(&w.allCount, uint32(Udelta))
|
||||
}
|
||||
break
|
||||
w.acquire(delta)
|
||||
}
|
||||
|
||||
func (w *WaitGroup) acquire(delta int) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if w.maxCount <= 0 {
|
||||
w.running += delta
|
||||
w.wg.Add(delta)
|
||||
return
|
||||
}
|
||||
if delta == 1 {
|
||||
w.wg.Add(1)
|
||||
for w.maxCount > 0 && w.running >= w.maxCount {
|
||||
w.cond.Wait()
|
||||
}
|
||||
time.Sleep(time.Microsecond)
|
||||
w.running++
|
||||
return
|
||||
}
|
||||
if w.running+delta > w.maxCount {
|
||||
if w.addMode == waitGroupAddModeStrict {
|
||||
panic(fmt.Sprintf("stario: WaitGroup.Add(%d) exceeds max limit %d with %d running", delta, w.maxCount, w.running))
|
||||
}
|
||||
w.maxCount = w.running + delta
|
||||
}
|
||||
w.running += delta
|
||||
w.wg.Add(delta)
|
||||
}
|
||||
|
||||
func (w *WaitGroup) release(delta int) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if delta > w.running {
|
||||
panic(fmt.Sprintf("stario: WaitGroup.Done releases %d tasks but only %d running", delta, w.running))
|
||||
}
|
||||
w.wg.Add(-delta)
|
||||
w.running -= delta
|
||||
w.cond.Broadcast()
|
||||
}
|
||||
|
||||
func (w *WaitGroup) Done() {
|
||||
w.Add(-1)
|
||||
}
|
||||
|
||||
func (w *WaitGroup) Go(fn func()) {
|
||||
w.Add(1)
|
||||
go func() {
|
||||
defer w.Done()
|
||||
fn()
|
||||
}()
|
||||
}
|
||||
|
||||
func (w *WaitGroup) Wait() {
|
||||
w.init()
|
||||
w.wg.Wait()
|
||||
}
|
||||
|
||||
func (w *WaitGroup) GetMaxWaitNum() int {
|
||||
return int(atomic.LoadUint32(&w.maxCount))
|
||||
w.init()
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
return w.maxCount
|
||||
}
|
||||
|
||||
func (w *WaitGroup) SetMaxWaitNum(num int) {
|
||||
atomic.AddUint32(&w.maxCount, uint32(num))
|
||||
if num < 0 {
|
||||
panic("stario: negative max wait count")
|
||||
}
|
||||
w.init()
|
||||
w.mu.Lock()
|
||||
w.maxCount = num
|
||||
w.mu.Unlock()
|
||||
w.cond.Broadcast()
|
||||
}
|
||||
|
||||
func (w *WaitGroup) SetStrictAddMode(strict bool) {
|
||||
w.init()
|
||||
w.mu.Lock()
|
||||
if strict {
|
||||
w.addMode = waitGroupAddModeStrict
|
||||
} else {
|
||||
w.addMode = waitGroupAddModeLoose
|
||||
}
|
||||
w.mu.Unlock()
|
||||
w.cond.Broadcast()
|
||||
}
|
||||
|
||||
func (w *WaitGroup) StrictAddMode() bool {
|
||||
w.init()
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
return w.addMode == waitGroupAddModeStrict
|
||||
}
|
||||
|
||||
171
sync_test.go
Normal file
171
sync_test.go
Normal file
@ -0,0 +1,171 @@
|
||||
package stario
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestWaitGroupAddBlocksAtLimit(t *testing.T) {
|
||||
wg := NewWaitGroup(1)
|
||||
wg.Add(1)
|
||||
|
||||
unblocked := make(chan struct{})
|
||||
go func() {
|
||||
wg.Add(1)
|
||||
close(unblocked)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-unblocked:
|
||||
t.Fatal("Add(1) should block while running count is at the limit")
|
||||
case <-time.After(30 * time.Millisecond):
|
||||
}
|
||||
|
||||
wg.Done()
|
||||
|
||||
select {
|
||||
case <-unblocked:
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("Add(1) did not resume after capacity was released")
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestWaitGroupStrictAddPanicsWhenBatchExceedsLimit(t *testing.T) {
|
||||
wg := NewWaitGroup(1)
|
||||
wg.Add(1)
|
||||
defer wg.Done()
|
||||
|
||||
defer func() {
|
||||
if recover() == nil {
|
||||
t.Fatal("expected Add(2) to panic in strict mode when it exceeds the limit")
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Add(2)
|
||||
}
|
||||
|
||||
func TestWaitGroupLooseAddExpandsLimit(t *testing.T) {
|
||||
wg := NewWaitGroup(1)
|
||||
wg.SetStrictAddMode(false)
|
||||
|
||||
wg.Add(1)
|
||||
wg.Add(2)
|
||||
|
||||
if got := wg.GetMaxWaitNum(); got != 3 {
|
||||
t.Fatalf("unexpected max count: got %d want 3", got)
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
wg.Done()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestWaitGroupSetMaxWaitNumOverridesValue(t *testing.T) {
|
||||
wg := NewWaitGroup(1)
|
||||
wg.SetMaxWaitNum(4)
|
||||
if got := wg.GetMaxWaitNum(); got != 4 {
|
||||
t.Fatalf("unexpected max count: got %d want 4", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWaitGroupSetMaxWaitNumZeroUnblocksWaitingAdd(t *testing.T) {
|
||||
wg := NewWaitGroup(1)
|
||||
wg.Add(1)
|
||||
|
||||
unblocked := make(chan struct{})
|
||||
go func() {
|
||||
wg.Add(1)
|
||||
close(unblocked)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-unblocked:
|
||||
t.Fatal("Add(1) should still be blocked before removing the limit")
|
||||
case <-time.After(30 * time.Millisecond):
|
||||
}
|
||||
|
||||
wg.SetMaxWaitNum(0)
|
||||
|
||||
select {
|
||||
case <-unblocked:
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("Add(1) did not resume after the limit was removed")
|
||||
}
|
||||
|
||||
wg.Done()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestWaitGroupGoWaitsForCompletion(t *testing.T) {
|
||||
wg := NewWaitGroup(1)
|
||||
done := make(chan struct{})
|
||||
|
||||
wg.Go(func() {
|
||||
close(done)
|
||||
})
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("Go(fn) did not run the function")
|
||||
}
|
||||
|
||||
finished := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(finished)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-finished:
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("Wait did not observe Go(fn) completion")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWaitGroupGoRespectsLimit(t *testing.T) {
|
||||
wg := NewWaitGroup(1)
|
||||
releaseFirst := make(chan struct{})
|
||||
firstStarted := make(chan struct{})
|
||||
secondStarted := make(chan struct{})
|
||||
secondReturned := make(chan struct{})
|
||||
|
||||
wg.Go(func() {
|
||||
close(firstStarted)
|
||||
<-releaseFirst
|
||||
})
|
||||
|
||||
select {
|
||||
case <-firstStarted:
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("first Go(fn) did not start")
|
||||
}
|
||||
|
||||
go func() {
|
||||
wg.Go(func() {
|
||||
close(secondStarted)
|
||||
})
|
||||
close(secondReturned)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-secondReturned:
|
||||
t.Fatal("second Go(fn) should block while the limit is full")
|
||||
case <-time.After(30 * time.Millisecond):
|
||||
}
|
||||
|
||||
close(releaseFirst)
|
||||
|
||||
select {
|
||||
case <-secondStarted:
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("second Go(fn) did not start after capacity was released")
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user