package notify

import (
	"errors"
	"sync"
	"time"
)

var (
	// errSignalAckCanceled is returned by waitPrepared when the wait's reply
	// channel is closed without an ack value (cancel, closeAll, closeScope,
	// closeScopeFamily) or when there is no wait to block on.
	errSignalAckCanceled = errors.New("signal ack canceled")
	// errSignalAckTimeout is returned by waitPrepared when no ack arrives
	// before the timeout elapses.
	errSignalAckTimeout = errors.New("signal ack timeout")
)

// signalAckShardCount is the number of independently locked shards in a
// signalAckPool.
const signalAckShardCount = 64

// signalAckMapKey identifies a pending ack by its normalized scope and
// signal ID.
type signalAckMapKey struct {
	scope    string
	signalID uint64
}

// signalAckShard is one mutex-protected partition of a signalAckPool.
type signalAckShard struct {
	mu   sync.Mutex
	wait map[signalAckMapKey]*signalAckWait
}

// signalAckWait is a single pending ack registration created by prepare.
// Its reply channel has capacity 1; it receives nil on ack and is closed
// exactly once, either by ack or by closeReply.
type signalAckWait struct {
	key       signalAckMapKey
	scope     string // same value as key.scope; used by the close* filters
	pool      *signalAckPool
	reply     chan error
	closeOnce sync.Once
}

// signalAckPool tracks pending acks, sharded by (scope, signalID).
type signalAckPool struct {
	shards [signalAckShardCount]signalAckShard
}

// newSignalAckPool returns a pool with every shard's map initialized.
func newSignalAckPool() *signalAckPool {
	pool := &signalAckPool{}
	for i := range pool.shards {
		pool.shards[i].wait = make(map[signalAckMapKey]*signalAckWait)
	}
	return pool
}

// signalAckScopeHash returns an FNV-1a-style 64-bit hash of scope. It is used
// only to choose a shard.
//
// NOTE(review): the seed is not the canonical FNV-1a 64-bit offset basis
// (14695981039346656037). Any fixed seed works for sharding, so it is left
// unchanged to keep shard assignment stable.
func signalAckScopeHash(scope string) uint64 {
	var hash uint64 = 1469598103934665603
	for i := 0; i < len(scope); i++ {
		hash ^= uint64(scope[i])
		hash *= 1099511628211
	}
	return hash
}

// shard returns the shard responsible for (scope, signalID), or nil when the
// pool itself is nil.
func (p *signalAckPool) shard(scope string, signalID uint64) *signalAckShard {
	if p == nil {
		return nil
	}
	index := int((signalID ^ signalAckScopeHash(scope)) % signalAckShardCount)
	return &p.shards[index]
}

// prepare creates a wait for (scope, signalID) and registers it in the pool.
// The scope is passed through normalizeFileScope so that it matches the
// normalization deliverAny applies.
//
// If a wait is already registered under the same key, it is replaced in the
// map. The replaced wait is not closed. Because it is no longer in the map,
// it can then only finish through its own timeout or cancel.
//
// On a nil pool the wait is returned unregistered.
func (p *signalAckPool) prepare(scope string, signalID uint64) *signalAckWait {
	scope = normalizeFileScope(scope)
	wait := &signalAckWait{
		key: signalAckMapKey{
			scope:    scope,
			signalID: signalID,
		},
		scope: scope,
		pool:  p,
		reply: make(chan error, 1),
	}
	if shard := p.shard(scope, signalID); shard != nil {
		shard.mu.Lock()
		shard.wait[wait.key] = wait
		shard.mu.Unlock()
	}
	return wait
}

// deliver acks the wait registered for (scope, signalID), if any, and reports
// whether one was found.
func (p *signalAckPool) deliver(scope string, signalID uint64) bool {
	return p.deliverAny([]string{scope}, signalID)
}

// deliverAny tries each scope in order (after normalizeFileScope). It removes
// and acks the first wait registered for (scope, signalID) and returns true.
// It returns false when the pool is nil or no scope matches.
func (p *signalAckPool) deliverAny(scopes []string, signalID uint64) bool {
	if p == nil {
		return false
	}
	for _, scope := range scopes {
		normalized := normalizeFileScope(scope)
		shard := p.shard(normalized, signalID)
		if shard == nil {
			continue
		}
		key := signalAckMapKey{
			scope:    normalized,
			signalID: signalID,
		}
		shard.mu.Lock()
		wait := shard.wait[key]
		if wait != nil {
			delete(shard.wait, key)
		}
		shard.mu.Unlock()
		if wait == nil {
			continue
		}
		wait.ack()
		return true
	}
	return false
}

// ack buffers a nil reply and closes the reply channel. It is a no-op if the
// channel was already closed by an earlier ack or closeReply.
func (w *signalAckWait) ack() {
	if w == nil {
		return
	}
	w.closeOnce.Do(func() {
		select {
		case w.reply <- nil:
		default:
		}
		close(w.reply)
	})
}

// cancel unregisters w from its pool and closes its reply channel.
//
// The map entry is removed only if it still refers to w. A later prepare with
// the same key may have replaced it, and a stale waiter must not unregister
// the newer one.
func (w *signalAckWait) cancel() {
	if w == nil {
		return
	}
	if w.pool != nil {
		if shard := w.pool.shard(w.key.scope, w.key.signalID); shard != nil {
			shard.mu.Lock()
			if shard.wait[w.key] == w {
				delete(shard.wait, w.key)
			}
			shard.mu.Unlock()
		}
	}
	w.closeReply()
}

// closeReply closes the reply channel without sending an ack. It is a no-op if
// the channel was already closed.
func (w *signalAckWait) closeReply() {
	if w == nil {
		return
	}
	w.closeOnce.Do(func() {
		close(w.reply)
	})
}

// waitPrepared blocks until wait is acked, closed, or timeout elapses. A
// non-positive timeout is replaced by defaultSignalAckTimeout.
//
// Return values:
//   - nil (the ack value) when the wait was acked, including an ack that races
//     with the timeout;
//   - errSignalAckCanceled when the channel was closed without an ack, or
//     wait is nil;
//   - errSignalAckTimeout when the timeout wins.
//
// The receiver p is not used.
func (p *signalAckPool) waitPrepared(wait *signalAckWait, timeout time.Duration) error {
	if wait == nil {
		return errSignalAckCanceled
	}
	if timeout <= 0 {
		timeout = defaultSignalAckTimeout
	}
	timer := time.NewTimer(timeout)
	defer timer.Stop()
	select {
	case err, ok := <-wait.reply:
		if !ok {
			return errSignalAckCanceled
		}
		return err
	case <-timer.C:
		wait.cancel()
		// Once cancel returns, closeOnce has completed, so the channel is
		// closed and this receive cannot block. A buffered value means an
		// ack landed concurrently with the timeout (select picks randomly
		// among ready cases); honor it instead of reporting a timeout.
		if err, ok := <-wait.reply; ok {
			return err
		}
		return errSignalAckTimeout
	}
}

// closeMatching removes every registered wait for which match returns true and
// closes its reply channel without an ack. A nil match removes every entry.
// Replies are closed after the shard lock is released.
func (p *signalAckPool) closeMatching(match func(*signalAckWait) bool) {
	if p == nil {
		return
	}
	for i := range p.shards {
		shard := &p.shards[i]
		var closing []*signalAckWait
		shard.mu.Lock()
		for key, wait := range shard.wait {
			if match != nil && (wait == nil || !match(wait)) {
				continue
			}
			delete(shard.wait, key)
			if wait != nil {
				closing = append(closing, wait)
			}
		}
		shard.mu.Unlock()
		for _, wait := range closing {
			wait.closeReply()
		}
	}
}

// closeAll unregisters every pending wait in the pool and closes its reply
// channel. Blocked waitPrepared callers then return errSignalAckCanceled.
func (p *signalAckPool) closeAll() {
	p.closeMatching(nil)
}

// closeScope unregisters and closes every pending wait whose normalized scope
// equals normalizeFileScope(scope).
func (p *signalAckPool) closeScope(scope string) {
	if p == nil {
		return
	}
	scope = normalizeFileScope(scope)
	p.closeMatching(func(wait *signalAckWait) bool {
		return wait.scope == scope
	})
}

// closeScopeFamily unregisters and closes every pending wait whose scope
// satisfies scopeBelongsToServerFileScope(wait.scope, base), where base is the
// normalized scope.
// NOTE(review): presumably this matches base and its derived scopes; confirm
// against scopeBelongsToServerFileScope.
func (p *signalAckPool) closeScopeFamily(scope string) {
	if p == nil {
		return
	}
	base := normalizeFileScope(scope)
	p.closeMatching(func(wait *signalAckWait) bool {
		return scopeBelongsToServerFileScope(wait.scope, base)
	})
}

// getSignalAckPool returns the signalAckWaits pool from the client's logical
// session state.
func (c *ClientCommon) getSignalAckPool() *signalAckPool {
	return c.getLogicalSessionState().signalAckWaits
}

// getSignalAckPool returns the signalAckWaits pool from the server's logical
// session state.
func (s *ServerCommon) getSignalAckPool() *signalAckPool {
	return s.getLogicalSessionState().signalAckWaits
}