notify/signal_ack.go

242 lines
4.6 KiB
Go
Raw Normal View History

package notify
import (
"errors"
"sync"
"time"
)
// Sentinel errors reported to callers waiting on a signal acknowledgement.
var (
// errSignalAckCanceled is returned by waitPrepared when the reply channel is
// closed without an acknowledgement having been sent on it.
errSignalAckCanceled = errors.New("signal ack canceled")
// errSignalAckTimeout is returned by waitPrepared when its timer fires
// before anything arrives on the reply channel.
errSignalAckTimeout = errors.New("signal ack timeout")
)
// signalAckShardCount is the number of independently locked shards in a
// signalAckPool; shard indexes are computed modulo this value.
const signalAckShardCount = 64
// signalAckMapKey identifies one pending acknowledgement inside a shard's
// wait map.
type signalAckMapKey struct {
// scope is the scope string the wait is registered under; prepare and
// deliverAny pass it through normalizeFileScope before building a key.
scope string
// signalID is the signal identifier paired with scope.
signalID uint64
}
// signalAckShard is one slice of a signalAckPool: a map of pending waits
// together with the mutex that is held around every access to it in this file.
type signalAckShard struct {
// mu is locked around reads and writes of wait.
mu sync.Mutex
// wait holds the pending acknowledgements of this shard, keyed by
// (scope, signalID).
wait map[signalAckMapKey]*signalAckWait
}
// signalAckWait is a single pending acknowledgement created by prepare. It is
// completed by ack (a nil is sent, then the channel is closed) or abandoned by
// cancel/closeReply (the channel is closed without a value).
type signalAckWait struct {
// key is the map key this wait is registered under.
key signalAckMapKey
// scope is set by prepare to the same normalized scope stored in key.scope;
// closeScope and closeScopeFamily match against it.
scope string
// pool is the pool that created the wait; cancel uses it to find the shard
// to deregister from. It may be nil.
pool *signalAckPool
// reply carries the outcome to waitPrepared. prepare creates it with a
// buffer of one.
reply chan error
// closeOnce is shared by ack and closeReply so reply is closed exactly once.
closeOnce sync.Once
}
// signalAckPool tracks pending signal acknowledgements, spread over a fixed
// array of separately locked shards. Its methods in this file tolerate a nil
// receiver.
type signalAckPool struct {
// shards is indexed by the value computed in (*signalAckPool).shard.
shards [signalAckShardCount]signalAckShard
}
// newSignalAckPool returns a pool in which every shard already has an empty,
// non-nil wait map, so entries can be stored without further setup.
func newSignalAckPool() *signalAckPool {
	p := new(signalAckPool)
	for idx := 0; idx < len(p.shards); idx++ {
		p.shards[idx].wait = map[signalAckMapKey]*signalAckWait{}
	}
	return p
}
// signalAckScopeHash folds the bytes of scope into a 64-bit value in the
// FNV-1a manner: XOR in each byte, then multiply by the 64-bit FNV prime.
//
// NOTE(review): the seed below is not the standard 64-bit FNV offset basis
// (14695981039346656037), so results differ from hash/fnv. Within this file
// the value is only used to pick a shard, so any fixed seed works; do not
// "correct" it if the value is persisted or compared anywhere else.
func signalAckScopeHash(scope string) uint64 {
	const (
		seed  uint64 = 1469598103934665603
		prime uint64 = 1099511628211
	)
	sum := seed
	for _, b := range []byte(scope) {
		sum = (sum ^ uint64(b)) * prime
	}
	return sum
}
// shard selects the shard responsible for the (scope, signalID) pair by
// XOR-ing the signal ID with the scope hash and reducing the result modulo
// signalAckShardCount. scope is hashed as given; callers in this file pass an
// already normalized scope. A nil pool yields nil.
func (p *signalAckPool) shard(scope string, signalID uint64) *signalAckShard {
	if p == nil {
		return nil
	}
	mixed := signalAckScopeHash(scope) ^ signalID
	slot := int(mixed % signalAckShardCount)
	return &p.shards[slot]
}
// prepare creates a wait for (scope, signalID) and registers it in the pool.
//
// The scope is normalized with normalizeFileScope before it is used as a key.
// The returned wait has a reply channel with a buffer of one. When p is nil
// the wait is still returned but is not registered anywhere. A registration
// that already exists under the same key is replaced in the map; the earlier
// wait itself is not closed or notified here.
func (p *signalAckPool) prepare(scope string, signalID uint64) *signalAckWait {
	normalized := normalizeFileScope(scope)
	key := signalAckMapKey{scope: normalized, signalID: signalID}
	pending := &signalAckWait{
		key:   key,
		scope: normalized,
		pool:  p,
		reply: make(chan error, 1),
	}

	bucket := p.shard(normalized, signalID)
	if bucket == nil {
		return pending
	}
	bucket.mu.Lock()
	bucket.wait[key] = pending
	bucket.mu.Unlock()
	return pending
}
// deliver acknowledges the wait registered for (scope, signalID), if any. It
// is deliverAny with a single candidate scope and reports whether a wait was
// found and acknowledged.
func (p *signalAckPool) deliver(scope string, signalID uint64) bool {
	candidates := []string{scope}
	return p.deliverAny(candidates, signalID)
}
// deliverAny tries each scope in order and acknowledges the first wait found
// for signalID under that (normalized) scope.
//
// The matching wait is removed from its shard while the shard lock is held;
// ack is called after the lock has been released. At most one wait is
// acknowledged per call. It returns true if a wait was acknowledged, and
// false if none matched or p is nil.
func (p *signalAckPool) deliverAny(scopes []string, signalID uint64) bool {
	if p == nil {
		return false
	}
	for i := range scopes {
		key := signalAckMapKey{
			scope:    normalizeFileScope(scopes[i]),
			signalID: signalID,
		}
		bucket := p.shard(key.scope, signalID)
		if bucket == nil {
			continue
		}

		bucket.mu.Lock()
		pending := bucket.wait[key]
		if pending != nil {
			delete(bucket.wait, key)
		}
		bucket.mu.Unlock()

		if pending != nil {
			pending.ack()
			return true
		}
	}
	return false
}
// ack completes the wait successfully: it attempts a non-blocking send of nil
// on reply and then closes the channel. The whole step runs under closeOnce,
// so it is a no-op if the wait was already acknowledged or closed. Calling it
// on a nil wait does nothing.
func (w *signalAckWait) ack() {
	if w == nil {
		return
	}
	signalAndClose := func() {
		// Non-blocking send: never stall the deliverer if the buffer is full.
		select {
		case w.reply <- nil:
		default:
		}
		close(w.reply)
	}
	w.closeOnce.Do(signalAndClose)
}
// cancel withdraws w from its pool (if it has one) and closes its reply
// channel without sending a value. Calling it on a nil wait does nothing, and
// calling it after the wait was acknowledged or closed leaves reply untouched
// because the close is guarded by closeOnce.
//
// The pool entry is removed only if it still refers to w. prepare replaces an
// existing entry when the same (scope, signalID) is registered again, so an
// unconditional delete would let a stale wait that times out remove a newer
// registration, whose acknowledgement could then never be delivered.
func (w *signalAckWait) cancel() {
	if w == nil {
		return
	}
	if w.pool != nil {
		if shard := w.pool.shard(w.key.scope, w.key.signalID); shard != nil {
			shard.mu.Lock()
			if shard.wait[w.key] == w {
				delete(shard.wait, w.key)
			}
			shard.mu.Unlock()
		}
	}
	w.closeReply()
}
// closeReply closes the reply channel without sending a value, at most once
// (it shares closeOnce with ack). It does not touch the pool's maps. Calling
// it on a nil wait does nothing.
func (w *signalAckWait) closeReply() {
	if w != nil {
		w.closeOnce.Do(func() { close(w.reply) })
	}
}
// waitPrepared blocks until wait is acknowledged, closed, or timeout elapses.
//
// A non-positive timeout is replaced by defaultSignalAckTimeout. The receiver
// p is not used.
//
// Returns:
//   - the value sent on reply (nil for an acknowledgement made via ack);
//   - errSignalAckCanceled if reply was closed without a value, or wait is nil;
//   - errSignalAckTimeout if the timer fired and no acknowledgement was
//     delivered; the wait is cancelled in that case.
func (p *signalAckPool) waitPrepared(wait *signalAckWait, timeout time.Duration) error {
	// The other signalAckWait entry points tolerate nil; report it as a
	// cancelled wait instead of panicking on wait.reply.
	if wait == nil {
		return errSignalAckCanceled
	}
	if timeout <= 0 {
		timeout = defaultSignalAckTimeout
	}
	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case err, ok := <-wait.reply:
		if !ok {
			return errSignalAckCanceled
		}
		return err
	case <-timer.C:
		wait.cancel()
		// An ack may have landed at the same moment the timer fired (select
		// picks among ready cases at random) or just before cancel ran. The
		// deliverer was already told the ack succeeded, so honour a buffered
		// value rather than reporting a timeout. Non-blocking by construction.
		select {
		case err, ok := <-wait.reply:
			if ok {
				return err
			}
		default:
		}
		return errSignalAckTimeout
	}
}
// closeAll empties every shard and closes the reply channel of each wait that
// was registered. Entries are removed while the shard lock is held; the reply
// channels are closed after the lock has been released. A nil pool is a no-op.
func (p *signalAckPool) closeAll() {
	if p == nil {
		return
	}
	for idx := 0; idx < len(p.shards); idx++ {
		bucket := &p.shards[idx]

		bucket.mu.Lock()
		drained := make([]*signalAckWait, 0, len(bucket.wait))
		for key, pending := range bucket.wait {
			delete(bucket.wait, key)
			if pending == nil {
				continue
			}
			drained = append(drained, pending)
		}
		bucket.mu.Unlock()

		for _, pending := range drained {
			pending.closeReply()
		}
	}
}
// closeScope removes every wait whose scope equals the normalized form of
// scope and closes its reply channel. All shards are scanned, because the
// shard index also depends on the signal ID. Entries are removed under the
// shard lock; channels are closed after it is released. A nil pool is a no-op.
func (p *signalAckPool) closeScope(scope string) {
	if p == nil {
		return
	}
	target := normalizeFileScope(scope)
	for idx := range p.shards {
		bucket := &p.shards[idx]
		var matched []*signalAckWait

		bucket.mu.Lock()
		for key, pending := range bucket.wait {
			if pending == nil || pending.scope != target {
				continue
			}
			delete(bucket.wait, key)
			matched = append(matched, pending)
		}
		bucket.mu.Unlock()

		for _, pending := range matched {
			pending.closeReply()
		}
	}
}
// closeScopeFamily removes every wait whose scope is accepted by
// scopeBelongsToServerFileScope relative to the normalized form of scope, and
// closes its reply channel. All shards are scanned. Entries are removed under
// the shard lock; channels are closed after it is released. A nil pool is a
// no-op.
func (p *signalAckPool) closeScopeFamily(scope string) {
	if p == nil {
		return
	}
	root := normalizeFileScope(scope)
	for idx := range p.shards {
		bucket := &p.shards[idx]
		var matched []*signalAckWait

		bucket.mu.Lock()
		for key, pending := range bucket.wait {
			if pending == nil || !scopeBelongsToServerFileScope(pending.scope, root) {
				continue
			}
			delete(bucket.wait, key)
			matched = append(matched, pending)
		}
		bucket.mu.Unlock()

		for _, pending := range matched {
			pending.closeReply()
		}
	}
}
// getSignalAckPool returns the signalAckWaits pool held by the client's
// logical session state.
func (c *ClientCommon) getSignalAckPool() *signalAckPool {
	state := c.getLogicalSessionState()
	return state.signalAckWaits
}
// getSignalAckPool returns the signalAckWaits pool held by the server's
// logical session state.
func (s *ServerCommon) getSignalAckPool() *signalAckPool {
	state := s.getLogicalSessionState()
	return state.signalAckWaits
}