bug fix:goroutine security improved

master
兔子 4 years ago
parent 79dcaaf249
commit 07e374b83f

@ -7,6 +7,7 @@ import (
"math/rand" "math/rand"
"net" "net"
"strings" "strings"
"sync"
"time" "time"
"b612.me/starnet" "b612.me/starnet"
@ -17,6 +18,7 @@ type StarNotifyC struct {
Connc net.Conn Connc net.Conn
dialTimeout time.Duration dialTimeout time.Duration
clientSign map[string]chan string clientSign map[string]chan string
mu sync.Mutex
// FuncLists 当不使用channel时使用此记录调用函数 // FuncLists 当不使用channel时使用此记录调用函数
FuncLists map[string]func(CMsg) FuncLists map[string]func(CMsg)
stopSign context.Context stopSign context.Context
@ -62,7 +64,9 @@ func (star *StarNotifyC) starinitc() {
func (star *StarNotifyC) Notify(key string) chan string { func (star *StarNotifyC) Notify(key string) chan string {
if _, ok := star.clientSign[key]; !ok { if _, ok := star.clientSign[key]; !ok {
ch := make(chan string, 20) ch := make(chan string, 20)
star.mu.Lock()
star.clientSign[key] = ch star.clientSign[key] = ch
star.mu.Unlock()
} }
return star.clientSign[key] return star.clientSign[key]
} }
@ -71,7 +75,9 @@ func (star *StarNotifyC) store(key, value string) {
if _, ok := star.clientSign[key]; !ok { if _, ok := star.clientSign[key]; !ok {
ch := make(chan string, 20) ch := make(chan string, 20)
ch <- value ch <- value
star.mu.Lock()
star.clientSign[key] = ch star.clientSign[key] = ch
star.mu.Unlock()
return return
} }
star.clientSign[key] <- value star.clientSign[key] <- value
@ -212,7 +218,7 @@ func (star *StarNotifyC) SendValueWait(name, value string, tmout time.Duration)
return CMsg{}, errors.New("Do Not Use UseChannel Mode!") return CMsg{}, errors.New("Do Not Use UseChannel Mode!")
} }
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
mode := "cr" + fmt.Sprintf("%05d", rand.Intn(99999)) mode := "cr" + fmt.Sprintf("%d%06d", time.Now().UnixNano(), rand.Intn(999999))
var key []byte var key []byte
for _, v := range []byte(name) { for _, v := range []byte(name) {
if v == byte(124) || v == byte(92) { if v == byte(124) || v == byte(92) {
@ -229,11 +235,15 @@ func (star *StarNotifyC) SendValueWait(name, value string, tmout time.Duration)
} }
var source CMsg var source CMsg
source.wait = make(chan int, 2) source.wait = make(chan int, 2)
star.mu.Lock()
star.lockPool[mode] = source star.lockPool[mode] = source
star.mu.Unlock()
select { select {
case <-source.wait: case <-source.wait:
res := star.lockPool[mode] res := star.lockPool[mode]
star.mu.Lock()
delete(star.lockPool, mode) delete(star.lockPool, mode)
star.mu.Unlock()
return res, nil return res, nil
case <-tmceed: case <-tmceed:
return CMsg{}, errors.New("Time Exceed") return CMsg{}, errors.New("Time Exceed")
@ -263,7 +273,7 @@ func (star *StarNotifyC) cnotify() {
} }
data, err := star.Queue.RestoreOne() data, err := star.Queue.RestoreOne()
if err != nil { if err != nil {
time.Sleep(time.Millisecond * 20) time.Sleep(time.Microsecond * 2)
continue continue
} }
if string(data.Msg) == "b612ryzstop" { if string(data.Msg) == "b612ryzstop" {
@ -301,7 +311,9 @@ func (star *StarNotifyC) cnotify() {
sa.Key = key sa.Key = key
sa.Value = value sa.Value = value
sa.mode = mode sa.mode = mode
star.mu.Lock()
star.lockPool[mode] = sa star.lockPool[mode] = sa
star.mu.Unlock()
sa.wait <- 1 sa.wait <- 1
} else { } else {
if msg, ok := star.FuncLists[key]; ok { if msg, ok := star.FuncLists[key]; ok {

@ -40,7 +40,7 @@ func Test_usechannel(t *testing.T) {
} }
func Test_nochannel(t *testing.T) { func Test_nochannel(t *testing.T) {
server, err := NewNotifyS("udp", "127.0.0.1:1926") server, err := NewNotifyS("tcp", "127.0.0.1:1926")
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
@ -53,7 +53,7 @@ func Test_nochannel(t *testing.T) {
} }
return "" return ""
}) })
client, err := NewNotifyC("udp", "127.0.0.1:1926") client, err := NewNotifyC("tcp", "127.0.0.1:1926")
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return

@ -27,3 +27,11 @@ func Decode(src []byte) (interface{}, error) {
err := dec.Decode(&dst) err := dec.Decode(&dst)
return dst, err return dst, err
} }
func (nmsg *SMsg) Decode() (interface{}, error) {
return Decode([]byte(nmsg.Value))
}
func (nmsg *CMsg) Decode() (interface{}, error) {
return Decode([]byte(nmsg.Value))
}

@ -8,6 +8,7 @@ import (
"math/rand" "math/rand"
"net" "net"
"strings" "strings"
"sync"
"time" "time"
"b612.me/starcrypto" "b612.me/starcrypto"
@ -39,13 +40,16 @@ type StarNotifyS struct {
Queue *starnet.StarQueue Queue *starnet.StarQueue
// FuncLists 记录了被通知项所记录的函数 // FuncLists 记录了被通知项所记录的函数
FuncLists map[string]func(SMsg) string FuncLists map[string]func(SMsg) string
funcMu sync.Mutex
defaultFunc func(SMsg) string defaultFunc func(SMsg) string
Connected func(SMsg) Connected func(SMsg)
nickName map[string]string nickName map[string]string
stopSign context.Context stopSign context.Context
cancel context.CancelFunc cancel context.CancelFunc
connPool map[string]net.Conn connPool sync.Map
connMu sync.Mutex
lockPool map[string]SMsg lockPool map[string]SMsg
lockMu sync.Mutex
udpPool map[string]*net.UDPAddr udpPool map[string]*net.UDPAddr
listener net.Listener listener net.Listener
isUDP bool isUDP bool
@ -90,19 +94,22 @@ func (star *StarNotifyS) getName(conn string) string {
// GetConnPool 获取所有Client端信息 // GetConnPool 获取所有Client端信息
func (star *StarNotifyS) GetConnPool() []SMsg { func (star *StarNotifyS) GetConnPool() []SMsg {
var result []SMsg var result []SMsg
for _, v := range star.connPool { star.connPool.Range(func(k, val interface{}) bool {
v := val.(net.Conn)
result = append(result, SMsg{Conn: v, mode: "pa", nickName: star.setNickName, getName: star.getName}) result = append(result, SMsg{Conn: v, mode: "pa", nickName: star.setNickName, getName: star.getName})
} return true
})
for _, v := range star.udpPool { for _, v := range star.udpPool {
result = append(result, SMsg{UDP: v, uconn: star.UDPConn, mode: "pa0", nickName: star.setNickName, getName: star.getName}) result = append(result, SMsg{UDP: v, uconn: star.UDPConn, mode: "pa0", nickName: star.setNickName, getName: star.getName})
} }
return result return result
} }
// GetConnPool 获取所有Client端信息 // GetClient 获取所有Client端信息
func (star *StarNotifyS) GetClient(name string) (SMsg, error) { func (star *StarNotifyS) GetClient(name string) (SMsg, error) {
if str, ok := star.nickName[name]; ok { if str, ok := star.nickName[name]; ok {
if conn, ok := star.connPool[str]; ok { if tmp, ok := star.connPool.Load(str); ok {
conn := tmp.(net.Conn)
return SMsg{Conn: conn, mode: "pa", nickName: star.setNickName, getName: star.getName}, nil return SMsg{Conn: conn, mode: "pa", nickName: star.setNickName, getName: star.getName}, nil
} }
if conn, ok := star.udpPool[str]; ok { if conn, ok := star.udpPool[str]; ok {
@ -188,7 +195,7 @@ func (star *StarNotifyS) SendWait(source SMsg, key, value string, tmout time.Dur
var err error var err error
var tmceed <-chan time.Time var tmceed <-chan time.Time
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
mode := "sr" + fmt.Sprintf("%05d", rand.Intn(99999)) mode := "sr" + fmt.Sprintf("%d%06d", time.Now().UnixNano(), rand.Intn(999999))
if source.uconn == nil { if source.uconn == nil {
_, err = source.Conn.Write(builder.BuildMessage([]byte(mode + "||" + source.addSlash(key) + "||" + value))) _, err = source.Conn.Write(builder.BuildMessage([]byte(mode + "||" + source.addSlash(key) + "||" + value)))
} else { } else {
@ -201,11 +208,15 @@ func (star *StarNotifyS) SendWait(source SMsg, key, value string, tmout time.Dur
tmceed = time.After(tmout) tmceed = time.After(tmout)
} }
source.wait = make(chan int, 2) source.wait = make(chan int, 2)
star.lockMu.Lock()
star.lockPool[mode] = source star.lockPool[mode] = source
star.lockMu.Unlock()
select { select {
case <-source.wait: case <-source.wait:
star.lockMu.Lock()
res := star.lockPool[mode] res := star.lockPool[mode]
delete(star.lockPool, mode) delete(star.lockPool, mode)
star.lockMu.Unlock()
return res, nil return res, nil
case <-tmceed: case <-tmceed:
return SMsg{}, errors.New("Time Exceed") return SMsg{}, errors.New("Time Exceed")
@ -220,7 +231,6 @@ func (star *StarNotifyS) starinits() {
star.Queue.Encode = true star.Queue.Encode = true
star.udpPool = make(map[string]*net.UDPAddr) star.udpPool = make(map[string]*net.UDPAddr)
star.FuncLists = make(map[string]func(SMsg) string) star.FuncLists = make(map[string]func(SMsg) string)
star.connPool = make(map[string]net.Conn)
star.nickName = make(map[string]string) star.nickName = make(map[string]string)
star.lockPool = make(map[string]SMsg) star.lockPool = make(map[string]SMsg)
star.Stop = make(chan int, 5) star.Stop = make(chan int, 5)
@ -253,7 +263,9 @@ func doudps(netype, value string) (*StarNotifyS, error) {
<-star.stopSign.Done() <-star.stopSign.Done()
for k, v := range star.udpPool { for k, v := range star.udpPool {
star.UDPConn.WriteToUDP(star.Queue.BuildMessage([]byte("b612ryzstop")), v) star.UDPConn.WriteToUDP(star.Queue.BuildMessage([]byte("b612ryzstop")), v)
star.connMu.Lock()
delete(star.udpPool, k) delete(star.udpPool, k)
star.connMu.Unlock()
for k2, v2 := range star.nickName { for k2, v2 := range star.nickName {
if v2 == k { if v2 == k {
delete(star.nickName, k2) delete(star.nickName, k2)
@ -275,7 +287,9 @@ func doudps(netype, value string) (*StarNotifyS, error) {
go star.Connected(SMsg{UDP: addr, uconn: star.UDPConn, nickName: star.setNickName, getName: star.getName}) go star.Connected(SMsg{UDP: addr, uconn: star.UDPConn, nickName: star.setNickName, getName: star.getName})
} }
} }
star.connMu.Lock()
star.udpPool[addr.String()] = addr star.udpPool[addr.String()] = addr
star.connMu.Unlock()
} }
if err != nil { if err != nil {
continue continue
@ -298,15 +312,20 @@ func notudps(netype, value string) (*StarNotifyS, error) {
go star.notify() go star.notify()
go func() { go func() {
<-star.stopSign.Done() <-star.stopSign.Done()
for k, v := range star.connPool { star.connPool.Range(func(a, b interface{}) bool {
k := a.(string)
v := b.(net.Conn)
v.Close() v.Close()
delete(star.connPool, k) star.connPool.Delete(a)
for k2, v2 := range star.nickName { for k2, v2 := range star.nickName {
if v2 == k { if v2 == k {
star.funcMu.Lock()
delete(star.nickName, k2) delete(star.nickName, k2)
star.funcMu.Unlock()
} }
} }
} return true
})
star.listener.Close() star.listener.Close()
star.Online = false star.Online = false
return return
@ -341,7 +360,7 @@ func notudps(netype, value string) (*StarNotifyS, error) {
} }
if err != nil { if err != nil {
conn.Close() conn.Close()
delete(star.connPool, fmt.Sprint(conn)) star.connPool.Delete(fmt.Sprint(conn))
for k, v := range star.nickName { for k, v := range star.nickName {
if v == fmt.Sprint(conn) { if v == fmt.Sprint(conn) {
delete(star.nickName, k) delete(star.nickName, k)
@ -351,7 +370,7 @@ func notudps(netype, value string) (*StarNotifyS, error) {
} }
} }
}(conn) }(conn)
star.connPool[fmt.Sprint(conn)] = conn star.connPool.Store(fmt.Sprint(conn), conn)
if star.Connected != nil { if star.Connected != nil {
go star.Connected(SMsg{Conn: conn, nickName: star.setNickName, getName: star.getName}) go star.Connected(SMsg{Conn: conn, nickName: star.setNickName, getName: star.getName})
} }
@ -367,7 +386,7 @@ func (star *StarNotifyS) GetListenerInfo() net.Listener {
// SetNotify 用于设置通知关键词的调用函数 // SetNotify 用于设置通知关键词的调用函数
func (star *StarNotifyS) setNickName(name string, conn string) error { func (star *StarNotifyS) setNickName(name string, conn string) error {
if _, ok := star.connPool[conn]; !ok { if _, ok := star.connPool.Load(conn); !ok {
if _, ok := star.udpPool[conn]; !ok { if _, ok := star.udpPool[conn]; !ok {
return errors.New("Conn Not Found") return errors.New("Conn Not Found")
} }
@ -377,12 +396,22 @@ func (star *StarNotifyS) setNickName(name string, conn string) error {
delete(star.nickName, k) delete(star.nickName, k)
} }
} }
star.funcMu.Lock()
star.nickName[name] = conn star.nickName[name] = conn
star.funcMu.Unlock()
return nil return nil
} }
// SetNotify 用于设置通知关键词的调用函数 // SetNotify 用于设置通知关键词的调用函数
func (star *StarNotifyS) SetNotify(name string, data func(SMsg) string) { func (star *StarNotifyS) SetNotify(name string, data func(SMsg) string) {
star.funcMu.Lock()
defer star.funcMu.Unlock()
if data == nil {
if _, ok := star.FuncLists[name]; ok {
delete(star.FuncLists, name)
}
return
}
star.FuncLists[name] = data star.FuncLists[name] = data
} }
@ -414,7 +443,7 @@ func (star *StarNotifyS) notify() {
} }
data, err := star.Queue.RestoreOne() data, err := star.Queue.RestoreOne()
if err != nil { if err != nil {
time.Sleep(time.Millisecond * 20) time.Sleep(time.Microsecond * 2)
continue continue
} }
mode, key, value := star.analyseData(string(data.Msg)) mode, key, value := star.analyseData(string(data.Msg))
@ -424,7 +453,9 @@ func (star *StarNotifyS) notify() {
} else { } else {
rmsg = SMsg{nil, key, value, data.Conn.(*net.UDPAddr), star.UDPConn, mode, nil, star.setNickName, star.getName} rmsg = SMsg{nil, key, value, data.Conn.(*net.UDPAddr), star.UDPConn, mode, nil, star.setNickName, star.getName}
if key == "b612ryzstop" { if key == "b612ryzstop" {
star.connMu.Lock()
delete(star.udpPool, rmsg.UDP.String()) delete(star.udpPool, rmsg.UDP.String())
star.connMu.Unlock()
for k, v := range star.nickName { for k, v := range star.nickName {
if v == rmsg.UDP.String() { if v == rmsg.UDP.String() {
delete(star.nickName, k) delete(star.nickName, k)
@ -451,7 +482,7 @@ func (star *StarNotifyS) notify() {
} }
} }
if mode[0:2] != "sr" { if mode[0:2] != "sr" {
if star.Sync { if !star.Sync {
go replyFunc(key, rmsg) go replyFunc(key, rmsg)
} else { } else {
replyFunc(key, rmsg) replyFunc(key, rmsg)
@ -459,10 +490,12 @@ func (star *StarNotifyS) notify() {
} else { } else {
if sa, ok := star.lockPool[mode]; ok { if sa, ok := star.lockPool[mode]; ok {
rmsg.wait = sa.wait rmsg.wait = sa.wait
star.lockMu.Lock()
star.lockPool[mode] = rmsg star.lockPool[mode] = rmsg
star.lockPool[mode].wait <- 1 star.lockPool[mode].wait <- 1
star.lockMu.Unlock()
} else { } else {
if star.Sync { if !star.Sync {
go replyFunc(key, rmsg) go replyFunc(key, rmsg)
} else { } else {
replyFunc(key, rmsg) replyFunc(key, rmsg)

Loading…
Cancel
Save