notify/stream_runtime.go

202 lines
4.3 KiB
Go
Raw Normal View History

package notify
import (
"context"
"fmt"
"strconv"
"strings"
"sync"
"sync/atomic"
)
// streamRuntime is the registry and ID allocator for streams. Both
// ClientCommon and ServerCommon hold one (see their getStreamRuntime
// accessors).
//
// Streams are indexed twice:
//   - streams, keyed by scope and string stream ID;
//   - data, keyed by scope and numeric data ID, for streams whose dataID is
//     non-zero.
//
// mu guards handler, streams and data. The ID counters are atomics and are
// read/advanced without taking mu. In the visible code, cfg and flow are only
// assigned in newStreamRuntime.
type streamRuntime struct {
	// rolePrefix is prepended to every ID produced by nextID ("<prefix>-<n>").
	rolePrefix string
	// seq is the counter behind nextID.
	seq atomic.Uint64
	// dataSeq is the counter behind nextDataID.
	dataSeq atomic.Uint64
	// mu guards handler, streams and data.
	mu sync.RWMutex
	// handler is stored by setHandler and read via handlerSnapshot.
	// NOTE(review): presumably invoked for inbound streams with their
	// StreamAcceptInfo — confirm at the call site.
	handler func(StreamAcceptInfo) error
	// streams maps streamRuntimeKey(scope, id) to the stream's handle.
	streams map[string]*streamHandle
	// data maps streamRuntimeDataKey(scope, dataID) to the same handle; only
	// streams registered with a non-zero dataID appear here.
	data map[string]*streamHandle
	// cfg is the stream configuration (defaultStreamConfig() in newStreamRuntime).
	cfg streamConfig
	// flow is the controller acquireOutbound delegates to; built from cfg.
	flow *streamFlowController
}
// newStreamRuntime returns an empty runtime whose nextID values are prefixed
// with rolePrefix. It starts from defaultStreamConfig and builds the flow
// controller from that same configuration.
func newStreamRuntime(rolePrefix string) *streamRuntime {
	rt := &streamRuntime{
		rolePrefix: rolePrefix,
		cfg:        defaultStreamConfig(),
		streams:    make(map[string]*streamHandle),
		data:       make(map[string]*streamHandle),
	}
	rt.flow = newStreamFlowController(rt.cfg)
	return rt
}
// nextID returns a fresh stream identifier of the form "<rolePrefix>-<n>".
// n comes from an atomic counter whose first value is 1. A nil runtime yields
// the empty string.
func (r *streamRuntime) nextID() string {
	if r != nil {
		n := r.seq.Add(1)
		return fmt.Sprintf("%s-%d", r.rolePrefix, n)
	}
	return ""
}
// nextDataID returns the next numeric data identifier, starting at 1. A nil
// runtime returns 0, which elsewhere in this file means "no data ID"
// (register skips the data index for it and lookupByDataID rejects it).
func (r *streamRuntime) nextDataID() uint64 {
	var id uint64
	if r != nil {
		id = r.dataSeq.Add(1)
	}
	return id
}
// setHandler stores fn as the runtime's handler, replacing any previous one.
// Passing nil clears it. Calling it on a nil runtime does nothing.
func (r *streamRuntime) setHandler(fn func(StreamAcceptInfo) error) {
	if r == nil {
		return
	}
	r.mu.Lock()
	r.handler = fn
	r.mu.Unlock()
}
// handlerSnapshot returns the handler most recently stored by setHandler. The
// value is read under the read lock, so callers can invoke it without holding
// r.mu. A nil runtime returns nil.
func (r *streamRuntime) handlerSnapshot() func(StreamAcceptInfo) error {
	if r == nil {
		return nil
	}
	r.mu.RLock()
	current := r.handler
	r.mu.RUnlock()
	return current
}
// register adds stream to the runtime under scope.
//
// The stream is always indexed by streamRuntimeKey(scope, stream.id). When
// stream.dataID is non-zero, it is also indexed by
// streamRuntimeDataKey(scope, stream.dataID). Both uniqueness checks run under
// the same write lock before anything is inserted, so a rejected registration
// leaves both maps unchanged.
//
// Errors:
//   - errStreamRuntimeNil if r is nil;
//   - errStreamIDEmpty if stream is nil or has an empty id;
//   - errStreamAlreadyExists if either key is already registered.
func (r *streamRuntime) register(scope string, stream *streamHandle) error {
	if r == nil {
		return errStreamRuntimeNil
	}
	if stream == nil || stream.id == "" {
		return errStreamIDEmpty
	}
	key := streamRuntimeKey(scope, stream.id)

	// dataID 0 means "not indexed by data ID". In that case the data key would
	// never be used, so only pay for normalizing and formatting it when needed.
	hasDataID := stream.dataID != 0
	var dataKey string
	if hasDataID {
		dataKey = streamRuntimeDataKey(scope, stream.dataID)
	}

	r.mu.Lock()
	defer r.mu.Unlock()
	if _, taken := r.streams[key]; taken {
		return errStreamAlreadyExists
	}
	if hasDataID {
		if _, taken := r.data[dataKey]; taken {
			return errStreamAlreadyExists
		}
		r.data[dataKey] = stream
	}
	r.streams[key] = stream
	return nil
}
// lookup returns the stream registered under (scope, streamID) and whether it
// was present. A nil runtime or an empty streamID is reported as not found.
func (r *streamRuntime) lookup(scope string, streamID string) (*streamHandle, bool) {
	if r == nil || streamID == "" {
		return nil, false
	}
	k := streamRuntimeKey(scope, streamID)
	r.mu.RLock()
	h, found := r.streams[k]
	r.mu.RUnlock()
	return h, found
}
// lookupByDataID returns the stream registered under (scope, dataID) in the
// secondary data-ID index and whether it was present. dataID 0 is never
// indexed (see register), so it is reported as not found, as is a nil runtime.
func (r *streamRuntime) lookupByDataID(scope string, dataID uint64) (*streamHandle, bool) {
	if r == nil || dataID == 0 {
		return nil, false
	}
	k := streamRuntimeDataKey(scope, dataID)
	r.mu.RLock()
	h, found := r.data[k]
	r.mu.RUnlock()
	return h, found
}
// remove unregisters the stream stored under (scope, streamID). If that stream
// has a non-zero dataID, its entry in the data-ID index is dropped too.
// Removing an unknown stream does nothing. So does calling it on a nil runtime
// or with an empty streamID.
func (r *streamRuntime) remove(scope string, streamID string) {
	if r == nil || streamID == "" {
		return
	}
	k := streamRuntimeKey(scope, streamID)
	r.mu.Lock()
	defer r.mu.Unlock()
	existing := r.streams[k]
	delete(r.streams, k)
	if existing == nil || existing.dataID == 0 {
		return
	}
	delete(r.data, streamRuntimeDataKey(scope, existing.dataID))
}
// acquireOutbound delegates to the runtime's flow controller,
// r.flow.acquire(ctx, size), and returns its func and error unchanged.
// NOTE(review): the returned func is presumably the release callback for the
// acquired capacity — confirm against streamFlowController.
// With no runtime or no flow controller, it succeeds immediately and returns a
// no-op func.
func (r *streamRuntime) acquireOutbound(ctx context.Context, size int) (func(), error) {
	if r != nil && r.flow != nil {
		return r.flow.acquire(ctx, size)
	}
	noop := func() {}
	return noop, nil
}
// snapshots returns a StreamSnapshot for every registered (non-nil) stream,
// ordered by sortStreamSnapshots. A nil runtime yields nil.
//
// Handles are copied out while r.mu is read-held, and stream.snapshot() is
// called only after the lock is released. This mirrors closeMatching and has
// two effects:
//   - r.mu is never held across calls into streamHandle, so any locking done
//     by snapshot() cannot nest under the runtime lock;
//   - a panic inside snapshot() can no longer leave r.mu read-locked forever
//     (the previous version unlocked without defer).
func (r *streamRuntime) snapshots() []StreamSnapshot {
	if r == nil {
		return nil
	}
	r.mu.RLock()
	handles := make([]*streamHandle, 0, len(r.streams))
	for _, stream := range r.streams {
		if stream != nil {
			handles = append(handles, stream)
		}
	}
	r.mu.RUnlock()

	snapshots := make([]StreamSnapshot, 0, len(handles))
	for _, stream := range handles {
		snapshots = append(snapshots, stream.snapshot())
	}
	sortStreamSnapshots(snapshots)
	return snapshots
}
// closeAll resets every registered stream in every scope, using the error that
// streamRuntimeCloseError derives from err. A nil runtime is a no-op.
func (r *streamRuntime) closeAll(err error) {
	matchEverything := func(string) bool { return true }
	r.closeMatching(matchEverything, err)
}
// closeScope resets every stream registered under scope, after the scope is
// normalized with normalizeFileScope.
//
// Keys in r.streams have the form "<normalized scope>\x00<id>" (see
// streamRuntimeKey), so a prefix match on "<normalized scope>\x00" selects that
// scope's streams. NOTE(review): this assumes normalized scopes never contain
// "\x00" themselves.
func (r *streamRuntime) closeScope(scope string, err error) {
	prefix := normalizeFileScope(scope) + "\x00"
	inScope := func(key string) bool {
		return strings.HasPrefix(key, prefix)
	}
	r.closeMatching(inScope, err)
}
// closeMatching calls markReset(streamRuntimeCloseError(err)) on every
// non-nil stream whose r.streams key satisfies match.
//
// Matching handles are gathered while r.mu is read-held. markReset runs only
// after the lock is released, so the runtime lock is not held while streams
// are reset. Entries are not deleted from the maps here.
//
// A nil runtime or a nil match is a no-op.
func (r *streamRuntime) closeMatching(match func(string) bool, err error) {
	if r == nil || match == nil {
		return
	}
	resetErr := streamRuntimeCloseError(err)

	r.mu.RLock()
	targets := make([]*streamHandle, 0, len(r.streams))
	for key, stream := range r.streams {
		if stream != nil && match(key) {
			targets = append(targets, stream)
		}
	}
	r.mu.RUnlock()

	for i := range targets {
		targets[i].markReset(resetErr)
	}
}
// streamRuntimeKey builds the r.streams index key for a stream: the
// normalized scope, a NUL (0x00) separator, then the stream ID.
func streamRuntimeKey(scope string, streamID string) string {
	const separator = "\x00"
	return normalizeFileScope(scope) + separator + streamID
}
// streamRuntimeDataKey builds the r.data index key: the normalized scope, a
// 0x01 separator (distinct from the 0x00 used by streamRuntimeKey), then the
// data ID in decimal.
func streamRuntimeDataKey(scope string, dataID uint64) string {
	decimalID := strconv.FormatUint(dataID, 10)
	return normalizeFileScope(scope) + "\x01" + decimalID
}
// getStreamRuntime returns the client's stream runtime, or nil when c is nil.
func (c *ClientCommon) getStreamRuntime() *streamRuntime {
	var rt *streamRuntime
	if c != nil {
		rt = c.streamRuntime
	}
	return rt
}
// getStreamRuntime returns the server's stream runtime, or nil when s is nil.
func (s *ServerCommon) getStreamRuntime() *streamRuntime {
	var rt *streamRuntime
	if s != nil {
		rt = s.streamRuntime
	}
	return rt
}