diff --git a/client.go b/client.go index 0f0f5ab..eda1e56 100644 --- a/client.go +++ b/client.go @@ -2,6 +2,9 @@ package notify import ( "context" + "errors" + "fmt" + "math/rand" "net" "strings" "time" @@ -26,13 +29,16 @@ type StarNotifyC struct { // Queue 是用来处理收发信息的简单消息队列 Queue *starainrt.StarQueue // Online 当前链接是否处于活跃状态 - Online bool + Online bool + lockPool map[string]CMsg } // CMsg 指明当前客户端被通知的关键字 type CMsg struct { Key string Value string + mode string + wait chan int } func (star *StarNotifyC) starinitc() { @@ -43,6 +49,7 @@ func (star *StarNotifyC) starinitc() { star.Stop = make(chan int, 5) star.clientSign = make(map[string]chan string) star.Online = false + star.lockPool = make(map[string]CMsg) star.Queue.RestoreDuration(time.Second * 2) } @@ -107,13 +114,84 @@ func NewNotifyC(netype, value string) (*StarNotifyC, error) { // Send 用于向Server端发送数据 func (star *StarNotifyC) Send(name string) error { - _, err := star.Connc.Write(star.Queue.BuildMessage([]byte(name))) - return err + return star.SendValue(name, "") } // SendValue 用于向Server端发送key-value类型数据 func (star *StarNotifyC) SendValue(name, value string) error { - _, err := star.Connc.Write(star.Queue.BuildMessage([]byte(name + "||" + value))) + var err error + var key []byte + for _, v := range []byte(name) { + if v == byte(124) || v == byte(92) { + key = append(key, byte(92)) + } + key = append(key, v) + } + _, err = star.Connc.Write(star.Queue.BuildMessage([]byte("pa" + "||" + string(key) + "||" + value))) + return err +} + +func (star *StarNotifyC) trim(name string) string { + var slash bool = false + var key []byte + for _, v := range []byte(name) { + if v == byte(92) && !slash { + slash = true + continue + } + slash = false + key = append(key, v) + } + return string(key) +} + +// SendValueWait 用于向Server端发送key-value类型数据并等待结果返回,此结果不会通过标准返回流程处理 +func (star *StarNotifyC) SendValueWait(name, value string, tmout time.Duration) (CMsg, error) { + var err error + var tmceed <-chan time.Time + if star.UseChannel { + return CMsg{}, errors.New("Do Not Use UseChannel Mode!") + } + rand.Seed(time.Now().UnixNano()) + mode := "cr" + fmt.Sprintf("%05d", rand.Intn(99999)) + var key []byte + for _, v := range []byte(name) { + if v == byte(124) || v == byte(92) { + key = append(key, byte(92)) + } + key = append(key, v) + } + _, err = star.Connc.Write(star.Queue.BuildMessage([]byte(mode + "||" + string(key) + "||" + value))) + if err != nil { + return CMsg{}, err + } + if int64(tmout) > 0 { + tmceed = time.After(tmout) + } + var source CMsg + source.wait = make(chan int, 2) + star.lockPool[mode] = source + select { + case <-source.wait: + res := star.lockPool[mode] + delete(star.lockPool, mode) + return res, nil + case <-tmceed: + return CMsg{}, errors.New("Time Exceed") + } +} + +// ReplyMsg 用于向Server端Reply信息 +func (star *StarNotifyC) ReplyMsg(data CMsg, name, value string) error { + var err error + var key []byte + for _, v := range []byte(name) { + if v == byte(124) || v == byte(92) { + key = append(key, byte(92)) + } + key = append(key, v) + } + _, err = star.Connc.Write(star.Queue.BuildMessage([]byte(data.mode + "||" + string(key) + "||" + value))) return err } @@ -134,19 +212,38 @@ func (star *StarNotifyC) cnotify() { star.Online = false return } - strs := strings.SplitN(string(data.Msg), "||", 2) - if len(strs) < 2 { + strs := strings.SplitN(string(data.Msg), "||", 3) + if len(strs) < 3 { continue } + strs[1] = star.trim(strs[1]) if star.UseChannel { - go star.store(strs[0], strs[1]) + go star.store(strs[1], strs[2]) } else { - key, value := strs[0], strs[1] - if msg, ok := star.FuncLists[key]; ok { - go msg(CMsg{key, value}) + mode, key, value := strs[0], strs[1], strs[2] + if mode[0:2] != "cr" { + if msg, ok := star.FuncLists[key]; ok { + go msg(CMsg{key, value, mode, nil}) + } else { + if star.defaultFunc != nil { + go star.defaultFunc(CMsg{key, value, mode, nil}) + } + } } else { - if star.defaultFunc != nil { - go star.defaultFunc(CMsg{key, value}) + if sa, ok := star.lockPool[mode]; ok { + sa.Key = key + sa.Value = value + sa.mode = mode + star.lockPool[mode] = sa + sa.wait <- 1 + } else { + if msg, ok := star.FuncLists[key]; ok { + go msg(CMsg{key, value, mode, nil}) + } else { + if star.defaultFunc != nil { + go star.defaultFunc(CMsg{key, value, mode, nil}) + } + } } } } @@ -160,6 +257,8 @@ func (star *StarNotifyC) ClientStop() { } star.cancel() star.Stop <- 1 + star.Stop <- 1 + star.Stop <- 1 } // SetNotify 用于设置关键词的调用函数 @@ -168,6 +267,6 @@ func (star *StarNotifyC) SetNotify(name string, data func(CMsg)) { } // SetDefaultNotify 用于设置默认关键词的调用函数 -func (star *StarNotifyC) SetDefaultNotify(name string, data func(CMsg)) { +func (star *StarNotifyC) SetDefaultNotify(data func(CMsg)) { star.defaultFunc = data } diff --git a/client_test.go b/client_test.go index 433ecb7..b07d7ba 100644 --- a/client_test.go +++ b/client_test.go @@ -47,8 +47,8 @@ func Test_nochannel(t *testing.T) { server.SetNotify("nihao", func(data SMsg) string { fmt.Println("server recv:", data.Key, data.Value) if data.Value != "" { - data.Reply("nba") - return "nb" + data.Reply("nbaz") + return "" } return "" }) @@ -73,3 +73,77 @@ func Test_nochannel(t *testing.T) { client.ClientStop() time.Sleep(time.Second * 3) } + +func Test_pipec(t *testing.T) { + server, err := NewNotifyS("tcp", "127.0.0.1:1926") + if err != nil { + fmt.Println(err) + return + } + server.SetNotify("ni\\||hao", func(data SMsg) string { + fmt.Println("server recv:", data.Key, data.Value, data.mode) + if data.Value != "" { + data.Reply("nba") + return "" + } + return "" + }) + client, err := NewNotifyC("tcp", "127.0.0.1:1926") + if err != nil { + fmt.Println(err) + return + } + client.UseChannel = false + sa, err := client.SendValueWait("ni\\||hao", "lalaeee", time.Second*10) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(sa) + fmt.Println("sukidesu") + time.Sleep(time.Second * 3) + server.ServerStop() + <-client.Stop + client.ClientStop() + time.Sleep(time.Second * 2) +} + +func Test_pips(t *testing.T) { + var testmsg SMsg + server, err := NewNotifyS("udp", "127.0.0.1:1926") + if err != nil { + fmt.Println(err) + return + } + server.SetNotify("nihao", func(data SMsg) string { + fmt.Println("server recv:", data.Key, data.Value, data.mode) + testmsg = data + if data.Value != "" { + data.Reply("nbaz") + return "" + } + return "" + }) + client, err := NewNotifyC("udp", "127.0.0.1:1926") + if err != nil { + fmt.Println(err) + return + } + //time.Sleep(time.Second * 10) + client.UseChannel = false + client.SetNotify("nihao", func(data CMsg) { + fmt.Println("client recv:", data.Key, data.Value, data.mode) + if data.mode != "pa" { + time.Sleep(time.Millisecond * 1200) + client.ReplyMsg(data, "nihao", "dsb") + } + }) + client.SendValue("nihao", "lalala") + time.Sleep(time.Second * 3) + fmt.Println(server.SendWait(testmsg, "nihao", "wozuinb", time.Second*20)) + fmt.Println("sakura") + server.ServerStop() + <-client.Stop + client.ClientStop() + time.Sleep(time.Second * 3) +} diff --git a/server.go b/server.go index abd5898..bc77729 100644 --- a/server.go +++ b/server.go @@ -3,6 +3,9 @@ package notify import ( "context" + "errors" + "fmt" + "math/rand" "net" "strings" "time" @@ -23,11 +26,15 @@ type StarNotifyS struct { // FuncLists 记录了被通知项所记录的函数 FuncLists map[string]func(SMsg) string defaultFunc func(SMsg) string + Connected func(SMsg) string stopSign context.Context cancel context.CancelFunc connPool map[string]net.Conn + lockPool map[string]SMsg udpPool map[string]*net.UDPAddr isUDP bool + // Stop 停止信 号 + Stop chan int // UDPConn UDP监听 UDPConn *net.UDPConn // Online 当前链接是否处于活跃状态 @@ -48,15 +55,40 @@ type SMsg struct { Value string UDP *net.UDPAddr uconn *net.UDPConn + mode string + wait chan int +} + +// GetConnPool 获取所有Client端信息 +func (star *StarNotifyS) GetConnPool() []SMsg { + var result []SMsg + for _, v := range star.connPool { + result = append(result, SMsg{Conn: v, mode: "pa"}) + } + for _, v := range star.udpPool { + result = append(result, SMsg{UDP: v, uconn: star.UDPConn, mode: "pa0"}) + } + return result +} + +func (nmsg *SMsg) addSlash(name string) string { + var key []byte + for _, v := range []byte(name) { + if v == byte(124) || v == byte(92) { + key = append(key, byte(92)) + } + key = append(key, v) + } + return string(key) } // Reply 用于向client端回复数据 func (nmsg *SMsg) Reply(msg string) error { var err error if nmsg.uconn == nil { - _, err = nmsg.Conn.Write(builder.BuildMessage([]byte(nmsg.Key + "||" + msg))) + _, err = nmsg.Conn.Write(builder.BuildMessage([]byte(nmsg.mode + "||" + nmsg.addSlash(nmsg.Key) + "||" + msg))) } else { - _, err = nmsg.uconn.WriteToUDP(builder.BuildMessage([]byte(nmsg.Key+"||"+msg)), nmsg.UDP) + _, err = nmsg.uconn.WriteToUDP(builder.BuildMessage([]byte(nmsg.mode+"||"+nmsg.addSlash(nmsg.Key)+"||"+msg)), nmsg.UDP) } return err } @@ -65,19 +97,50 @@ func (nmsg *SMsg) Reply(msg string) error { func (nmsg *SMsg) Send(key, value string) error { var err error if nmsg.uconn == nil { - _, err = nmsg.Conn.Write(builder.BuildMessage([]byte(key + "||" + value))) + _, err = nmsg.Conn.Write(builder.BuildMessage([]byte("pa||" + nmsg.addSlash(key) + "||" + value))) } else { - _, err = nmsg.uconn.WriteToUDP(builder.BuildMessage([]byte(key+"||"+value)), nmsg.UDP) + _, err = nmsg.uconn.WriteToUDP(builder.BuildMessage([]byte("pa||"+nmsg.addSlash(key)+"||"+value)), nmsg.UDP) } return err } +// SendWait 用于向client端发送key-value数据,并等待 +func (star *StarNotifyS) SendWait(source SMsg, key, value string, tmout time.Duration) (SMsg, error) { + var err error + var tmceed <-chan time.Time + rand.Seed(time.Now().UnixNano()) + mode := "sr" + fmt.Sprintf("%05d", rand.Intn(99999)) + if source.uconn == nil { + _, err = source.Conn.Write(builder.BuildMessage([]byte(mode + "||" + source.addSlash(key) + "||" + value))) + } else { + _, err = source.uconn.WriteToUDP(builder.BuildMessage([]byte(mode+"||"+source.addSlash(key)+"||"+value)), source.UDP) + } + if err != nil { + return SMsg{}, err + } + if int64(tmout) > 0 { + tmceed = time.After(tmout) + } + source.wait = make(chan int, 2) + star.lockPool[mode] = source + select { + case <-source.wait: + res := star.lockPool[mode] + delete(star.lockPool, mode) + return res, nil + case <-tmceed: + return SMsg{}, errors.New("Time Exceed") + } +} + func (star *StarNotifyS) starinits() { star.stopSign, star.cancel = context.WithCancel(context.Background()) star.Queue = starainrt.NewQueue() star.udpPool = make(map[string]*net.UDPAddr) star.FuncLists = make(map[string]func(SMsg) string) star.connPool = make(map[string]net.Conn) + star.lockPool = make(map[string]SMsg) + star.Stop = make(chan int, 5) star.Online = false star.Queue.RestoreDuration(time.Second * 2) } @@ -119,6 +182,11 @@ func doudps(netype, value string) (*StarNotifyS, error) { n, addr, err := star.UDPConn.ReadFromUDP(buf) if n != 0 { star.Queue.ParseMessage(buf[0:n], addr) + if _, ok := star.udpPool[addr.String()]; !ok { + if star.Connected != nil { + go star.Connected(SMsg{UDP: addr, uconn: star.UDPConn}) + } + } star.udpPool[addr.String()] = addr } if err != nil { @@ -185,6 +253,9 @@ func notudps(netype, value string) (*StarNotifyS, error) { } }(conn) star.connPool[conn.RemoteAddr().String()] = conn + if star.Connected != nil { + go star.Connected(SMsg{Conn: conn}) + } } }() star.Online = true @@ -197,10 +268,24 @@ func (star *StarNotifyS) SetNotify(name string, data func(SMsg) string) { } // SetDefaultNotify 用于设置默认关键词的调用函数 -func (star *StarNotifyS) SetDefaultNotify(name string, data func(SMsg) string) { +func (star *StarNotifyS) SetDefaultNotify(data func(SMsg) string) { star.defaultFunc = data } +func (star *StarNotifyS) trim(name string) string { + var slash bool = false + var key []byte + for _, v := range []byte(name) { + if v == byte(92) && !slash { + slash = true + continue + } + slash = false + key = append(key, v) + } + return string(key) +} + func (star *StarNotifyS) notify() { for { select { @@ -213,46 +298,72 @@ func (star *StarNotifyS) notify() { time.Sleep(time.Millisecond * 20) continue } - key, value := analyseData(string(data.Msg)) + mode, key, value := star.analyseData(string(data.Msg)) var rmsg SMsg if !star.isUDP { - rmsg = SMsg{data.Conn.(net.Conn), key, value, nil, nil} + rmsg = SMsg{data.Conn.(net.Conn), key, value, nil, nil, mode, nil} } else { - rmsg = SMsg{nil, key, value, data.Conn.(*net.UDPAddr), star.UDPConn} + rmsg = SMsg{nil, key, value, data.Conn.(*net.UDPAddr), star.UDPConn, mode, nil} if key == "b612ryzstop" { delete(star.udpPool, rmsg.UDP.String()) continue } } - go func() { - if msg, ok := star.FuncLists[key]; ok { - sdata := msg(rmsg) - if sdata == "" { - return - } - rmsg.Reply(sdata) - } else { - if star.defaultFunc != nil { - sdata := star.defaultFunc(rmsg) + if mode[0:2] != "sr" { + go func() { + if msg, ok := star.FuncLists[key]; ok { + sdata := msg(rmsg) if sdata == "" { return } rmsg.Reply(sdata) + } else { + if star.defaultFunc != nil { + sdata := star.defaultFunc(rmsg) + if sdata == "" { + return + } + rmsg.Reply(sdata) + } } + }() + } else { + if sa, ok := star.lockPool[mode]; ok { + rmsg.wait = sa.wait + star.lockPool[mode] = rmsg + star.lockPool[mode].wait <- 1 + } else { + go func() { + if msg, ok := star.FuncLists[key]; ok { + sdata := msg(rmsg) + if sdata == "" { + return + } + rmsg.Reply(sdata) + } else { + if star.defaultFunc != nil { + sdata := star.defaultFunc(rmsg) + if sdata == "" { + return + } + rmsg.Reply(sdata) + } + } + }() } - }() + } } } -func analyseData(msg string) (key, value string) { - slice := strings.SplitN(msg, "||", 2) - if len(slice) == 1 { - return msg, "" - } - return slice[0], slice[1] +func (star *StarNotifyS) analyseData(msg string) (mode, key, value string) { + slice := strings.SplitN(msg, "||", 3) + return slice[0], star.trim(slice[1]), slice[2] } // ServerStop 用于终止Server端运行 func (star *StarNotifyS) ServerStop() { star.cancel() + star.Stop <- 1 + star.Stop <- 1 + star.Stop <- 1 }