// notify/stream_control.go

package notify
import (
"context"
"errors"
"time"
)
// StreamOpenRequest is the control payload sent under StreamOpenSignalKey to
// ask the peer to open a stream.
type StreamOpenRequest struct {
	// StreamID names the stream. decodeStreamOpenRequest rejects a request
	// whose StreamID is empty after normalization.
	StreamID string
	// DataID is the numeric stream identifier. When it is zero the inbound
	// open handler allocates one from its stream runtime.
	DataID uint64
	// Channel and Metadata are handed to newStreamHandle as part of the
	// request; presumably they are what the accept handler later sees via
	// StreamAcceptInfo (confirm against newStreamHandle).
	Channel  StreamChannel
	Metadata StreamMetadata
	// ReadTimeout and WriteTimeout are not interpreted in this file;
	// presumably per-stream I/O deadlines (confirm against newStreamHandle).
	ReadTimeout  time.Duration
	WriteTimeout time.Duration
}
// StreamOpenResponse is the reply to a StreamOpenRequest.
type StreamOpenResponse struct {
	// StreamID echoes the StreamID of the decoded request.
	StreamID string
	// DataID is the data identifier in effect for the stream, including one
	// allocated by the receiver when the request carried zero.
	DataID uint64
	// Accepted is true only when the stream was claimed or accepted.
	Accepted bool
	// TransportGeneration is filled from the stream handle on acceptance.
	TransportGeneration uint64
	// Error holds the text of the failure when the open was refused.
	Error string
}
// StreamCloseRequest is the control payload sent under StreamCloseSignalKey.
type StreamCloseRequest struct {
	// StreamID identifies the stream to close; it must not be empty.
	StreamID string
	// Full selects which close mark the receiver applies: markPeerClosed when
	// true, markRemoteClosed when false.
	Full bool
}
// StreamCloseResponse is the reply to a StreamCloseRequest.
type StreamCloseResponse struct {
	// StreamID echoes the StreamID of the decoded request.
	StreamID string
	// Accepted is true when the stream was found and marked closed.
	Accepted bool
	// Error holds the text of the failure when the close was refused.
	Error string
}
// StreamResetRequest is the control payload sent under StreamResetSignalKey.
type StreamResetRequest struct {
	// StreamID identifies the stream to reset.
	StreamID string
	// DataID is the fallback lookup key the reset handlers use when no stream
	// is registered under StreamID; zero disables the fallback.
	DataID uint64
	// Error is the sender's reason. A non-empty value becomes the text of the
	// reset error on the receiver; an empty one maps to errStreamReset.
	Error string
}
// StreamResetResponse is the reply to a StreamResetRequest.
type StreamResetResponse struct {
	// StreamID is the request's StreamID, or the ID of the stream that was
	// located when the request's StreamID was empty.
	StreamID string
	// Accepted is true when a stream was found and marked reset.
	Accepted bool
	// Error holds the text of the failure when the reset was refused.
	Error string
}
// bindClientStreamControl registers the client's inbound handlers for the
// stream open, close and reset control signals. A nil client is ignored.
func bindClientStreamControl(c *ClientCommon) {
	if c == nil {
		return
	}
	// Method values bind the receiver c at registration time.
	c.SetLink(StreamOpenSignalKey, c.handleInboundStreamOpen)
	c.SetLink(StreamCloseSignalKey, c.handleInboundStreamClose)
	c.SetLink(StreamResetSignalKey, c.handleInboundStreamReset)
}
// bindServerStreamControl registers the server's inbound handlers for the
// stream open, close and reset control signals. A nil server is ignored.
func bindServerStreamControl(s *ServerCommon) {
	if s == nil {
		return
	}
	// Method values bind the receiver s at registration time.
	s.SetLink(StreamOpenSignalKey, s.handleInboundStreamOpen)
	s.SetLink(StreamCloseSignalKey, s.handleInboundStreamClose)
	s.SetLink(StreamResetSignalKey, s.handleInboundStreamReset)
}
// handleInboundStreamOpen services a stream-open control message received by
// the client.
//
// It decodes the request, builds and registers a stream handle, then offers
// the stream first to the record claimer, then to the transfer claimer, and
// finally to the configured accept handler. Every exit path answers the peer
// through replyStreamControlIfNeeded; once the stream has been registered, a
// failure also marks the stream reset with the same error.
func (c *ClientCommon) handleInboundStreamOpen(msg *Message) {
	req, err := decodeStreamOpenRequest(msg)
	resp := StreamOpenResponse{StreamID: req.StreamID, DataID: req.DataID}

	// reject reports a failure to the peer without touching any stream.
	reject := func(cause error) {
		resp.Error = cause.Error()
		replyStreamControlIfNeeded(msg, resp)
	}

	if err != nil {
		reject(err)
		return
	}
	runtime := c.getStreamRuntime()
	if runtime == nil {
		reject(errStreamRuntimeNil)
		return
	}
	scope := clientFileScope()
	if req.DataID == 0 {
		// The peer left the data ID open; allocate one and report it back.
		req.DataID = runtime.nextDataID()
		resp.DataID = req.DataID
	}

	// Read the session epoch exactly once so the stream handle and its data
	// sender are bound to the same epoch even if the session changes while
	// the handle is being assembled.
	epoch := c.currentClientSessionEpoch()
	stream := newStreamHandle(
		c.clientStopContextSnapshot(), runtime, scope, req, epoch,
		nil, nil, 0,
		clientStreamCloseSender(c),
		clientStreamResetSender(c),
		clientStreamDataSender(c, epoch),
		runtime.configSnapshot(),
	)
	stream.setClientSnapshotOwner(c)
	stream.setAddrSnapshot(c.clientStreamAddrSnapshot())
	if err := runtime.register(scope, stream); err != nil {
		reject(err)
		return
	}

	// abort marks the registered stream reset and reports the same error.
	abort := func(cause error) {
		stream.markReset(cause)
		reject(cause)
	}
	// accept reports success together with the stream's transport generation.
	accept := func() {
		resp.Accepted = true
		resp.TransportGeneration = stream.TransportGeneration()
		replyStreamControlIfNeeded(msg, resp)
	}

	if claimed, err := c.claimInboundRecordStream(stream); claimed {
		if err != nil {
			abort(err)
		} else {
			accept()
		}
		return
	}
	if claimed, err := c.claimInboundTransferStream(stream); claimed {
		if err != nil {
			abort(err)
		} else {
			accept()
		}
		return
	}

	handler := runtime.handlerSnapshot()
	if handler == nil {
		abort(errStreamHandlerNotConfigured)
		return
	}
	info := StreamAcceptInfo{
		ID:                  stream.ID(),
		DataID:              stream.dataIDSnapshot(),
		Channel:             stream.Channel(),
		Metadata:            stream.Metadata(),
		TransportGeneration: stream.TransportGeneration(),
		Stream:              stream,
	}
	if err := handler(info); err != nil {
		abort(err)
		return
	}
	// Re-read the data ID after the handler has run before reporting it.
	resp.DataID = stream.dataIDSnapshot()
	accept()
}
// handleInboundStreamOpen services a stream-open control message received by
// the server.
//
// The request must decode, the stream runtime must exist and the message must
// resolve to a logical connection; otherwise the peer is told why. A stream
// handle is then registered under the logical connection's scope and offered
// to the record claimer, the transfer claimer and finally the configured
// accept handler. After registration, any failure also marks the stream reset.
func (s *ServerCommon) handleInboundStreamOpen(msg *Message) {
	req, decodeErr := decodeStreamOpenRequest(msg)
	resp := StreamOpenResponse{StreamID: req.StreamID, DataID: req.DataID}

	// refuse answers the peer with the failure text.
	refuse := func(cause error) {
		resp.Error = cause.Error()
		replyStreamControlIfNeeded(msg, resp)
	}

	if decodeErr != nil {
		refuse(decodeErr)
		return
	}
	runtime := s.getStreamRuntime()
	if runtime == nil {
		refuse(errStreamRuntimeNil)
		return
	}
	logical := messageLogicalConnSnapshot(msg)
	if logical == nil {
		refuse(errStreamLogicalConnNil)
		return
	}
	transport := messageTransportConnSnapshot(msg)
	scope := serverFileScope(logical)
	if req.DataID == 0 {
		// The peer left the data ID open; allocate one and report it back.
		req.DataID = runtime.nextDataID()
		resp.DataID = req.DataID
	}

	stream := newStreamHandle(
		logical.stopContextSnapshot(), runtime, scope, req, 0,
		logical, transport, streamTransportGeneration(logical, transport),
		serverStreamCloseSender(s, logical, transport),
		serverStreamResetSender(s, logical, transport),
		serverStreamDataSender(s, transport),
		runtime.configSnapshot(),
	)
	if regErr := runtime.register(scope, stream); regErr != nil {
		refuse(regErr)
		return
	}

	// tearDown marks the registered stream reset, then answers the peer.
	tearDown := func(cause error) {
		stream.markReset(cause)
		refuse(cause)
	}
	// admit answers the peer with success and the transport generation.
	admit := func() {
		resp.Accepted = true
		resp.TransportGeneration = stream.TransportGeneration()
		replyStreamControlIfNeeded(msg, resp)
	}

	if claimed, claimErr := s.claimInboundRecordStream(logical, transport, stream); claimed {
		if claimErr != nil {
			tearDown(claimErr)
		} else {
			admit()
		}
		return
	}
	if claimed, claimErr := s.claimInboundTransferStream(logical, transport, stream); claimed {
		if claimErr != nil {
			tearDown(claimErr)
		} else {
			admit()
		}
		return
	}

	handler := runtime.handlerSnapshot()
	if handler == nil {
		tearDown(errStreamHandlerNotConfigured)
		return
	}
	info := StreamAcceptInfo{
		ID:                  stream.ID(),
		DataID:              stream.dataIDSnapshot(),
		Channel:             stream.Channel(),
		Metadata:            stream.Metadata(),
		LogicalConn:         logical,
		TransportConn:       transport,
		TransportGeneration: stream.TransportGeneration(),
		Stream:              stream,
	}
	if acceptErr := handler(info); acceptErr != nil {
		tearDown(acceptErr)
		return
	}
	// Re-read the data ID after the handler has run before reporting it.
	resp.DataID = stream.dataIDSnapshot()
	admit()
}
// handleInboundStreamClose services a stream-close control message received
// by the client. The stream is looked up in the client scope by StreamID and
// marked peer-closed (Full) or remote-closed (not Full). The outcome is sent
// back through replyStreamControlIfNeeded.
func (c *ClientCommon) handleInboundStreamClose(msg *Message) {
	req, decodeErr := decodeStreamCloseRequest(msg)
	resp := StreamCloseResponse{StreamID: req.StreamID}

	// Run the close and collapse every failure into a single error value.
	failure := func() error {
		if decodeErr != nil {
			return decodeErr
		}
		runtime := c.getStreamRuntime()
		if runtime == nil {
			return errStreamRuntimeNil
		}
		stream, found := runtime.lookup(clientFileScope(), req.StreamID)
		if !found {
			return errStreamNotFound
		}
		if req.Full {
			stream.markPeerClosed()
		} else {
			stream.markRemoteClosed()
		}
		return nil
	}()

	if failure != nil {
		resp.Error = failure.Error()
	} else {
		resp.Accepted = true
	}
	replyStreamControlIfNeeded(msg, resp)
}
// handleInboundStreamClose services a stream-close control message received
// by the server. The stream is looked up by StreamID in the scope derived
// from the message's logical connection and marked peer-closed (Full) or
// remote-closed (not Full). The outcome is sent back through
// replyStreamControlIfNeeded.
func (s *ServerCommon) handleInboundStreamClose(msg *Message) {
	req, decodeErr := decodeStreamCloseRequest(msg)
	resp := StreamCloseResponse{StreamID: req.StreamID}

	// Run the close and collapse every failure into a single error value.
	failure := func() error {
		if decodeErr != nil {
			return decodeErr
		}
		runtime := s.getStreamRuntime()
		if runtime == nil {
			return errStreamRuntimeNil
		}
		// NOTE(review): the logical connection is not nil-checked here,
		// unlike in the open handler; serverFileScope receives it as is.
		scope := serverFileScope(messageLogicalConnSnapshot(msg))
		stream, found := runtime.lookup(scope, req.StreamID)
		if !found {
			return errStreamNotFound
		}
		if req.Full {
			stream.markPeerClosed()
		} else {
			stream.markRemoteClosed()
		}
		return nil
	}()

	if failure != nil {
		resp.Error = failure.Error()
	} else {
		resp.Accepted = true
	}
	replyStreamControlIfNeeded(msg, resp)
}
// handleInboundStreamReset services a stream-reset control message received
// by the client. The stream is located by StreamID, falling back to DataID
// when that lookup misses and DataID is non-zero, and is then marked reset
// with an error derived from the peer's reason text.
func (c *ClientCommon) handleInboundStreamReset(msg *Message) {
	req, decodeErr := decodeStreamResetRequest(msg)
	resp := StreamResetResponse{StreamID: req.StreamID}

	// fail answers the peer with the failure text.
	fail := func(cause error) {
		resp.Error = cause.Error()
		replyStreamControlIfNeeded(msg, resp)
	}

	if decodeErr != nil {
		fail(decodeErr)
		return
	}
	runtime := c.getStreamRuntime()
	if runtime == nil {
		fail(errStreamRuntimeNil)
		return
	}

	stream, found := runtime.lookup(clientFileScope(), req.StreamID)
	if !found && req.DataID != 0 {
		stream, found = runtime.lookupByDataID(clientFileScope(), req.DataID)
	}
	if !found {
		fail(errStreamNotFound)
		return
	}

	// Tell the peer which stream was hit when it did not name one.
	if resp.StreamID == "" {
		resp.StreamID = stream.ID()
	}
	remoteCause := streamRemoteResetError(req.Error)
	stream.markReset(streamResetError(remoteCause))
	resp.Accepted = true
	replyStreamControlIfNeeded(msg, resp)
}
// handleInboundStreamReset services a stream-reset control message received
// by the server. The stream is located in the scope derived from the
// message's logical connection, by StreamID first and by DataID when that
// misses and DataID is non-zero. It is then marked reset with an error
// derived from the peer's reason text.
func (s *ServerCommon) handleInboundStreamReset(msg *Message) {
	req, decodeErr := decodeStreamResetRequest(msg)
	resp := StreamResetResponse{StreamID: req.StreamID}

	// fail answers the peer with the failure text.
	fail := func(cause error) {
		resp.Error = cause.Error()
		replyStreamControlIfNeeded(msg, resp)
	}

	if decodeErr != nil {
		fail(decodeErr)
		return
	}
	runtime := s.getStreamRuntime()
	if runtime == nil {
		fail(errStreamRuntimeNil)
		return
	}

	// NOTE(review): the logical connection is not nil-checked here, unlike
	// in the open handler; serverFileScope receives it as is.
	scope := serverFileScope(messageLogicalConnSnapshot(msg))
	stream, found := runtime.lookup(scope, req.StreamID)
	if !found && req.DataID != 0 {
		stream, found = runtime.lookupByDataID(scope, req.DataID)
	}
	if !found {
		fail(errStreamNotFound)
		return
	}

	// Tell the peer which stream was hit when it did not name one.
	if resp.StreamID == "" {
		resp.StreamID = stream.ID()
	}
	remoteCause := streamRemoteResetError(req.Error)
	stream.markReset(streamResetError(remoteCause))
	resp.Accepted = true
	replyStreamControlIfNeeded(msg, resp)
}
// replyStreamControlIfNeeded sends value as the reply to msg, but only when
// msg is non-nil and requiresSignalReplyWait reports true for its transfer
// message. The result of ReplyObj is discarded.
func replyStreamControlIfNeeded(msg *Message, value interface{}) {
	if msg != nil && requiresSignalReplyWait(msg.TransferMsg) {
		// The reply error is deliberately dropped: this function has no
		// return value through which to report it.
		_ = msg.ReplyObj(value)
	}
}
// sendStreamOpenClient sends req under StreamOpenSignalKey through the client
// and decodes the peer's reply. It returns errStreamClientNil for a nil
// client, the send error if sending fails, and otherwise whatever
// decodeStreamOpenResponse yields.
func sendStreamOpenClient(ctx context.Context, c Client, req StreamOpenRequest) (StreamOpenResponse, error) {
	var none StreamOpenResponse
	if c == nil {
		return none, errStreamClientNil
	}
	reply, sendErr := c.SendObjCtx(ctx, StreamOpenSignalKey, req)
	if sendErr != nil {
		return none, sendErr
	}
	return decodeStreamOpenResponse(reply)
}
// sendStreamOpenServerLogical sends req under StreamOpenSignalKey to the given
// logical connection and decodes the peer's reply. A nil server yields
// errStreamServerNil and a nil logical connection errStreamLogicalConnNil.
func sendStreamOpenServerLogical(ctx context.Context, s Server, logical *LogicalConn, req StreamOpenRequest) (StreamOpenResponse, error) {
	var none StreamOpenResponse
	switch {
	case s == nil:
		return none, errStreamServerNil
	case logical == nil:
		return none, errStreamLogicalConnNil
	}
	reply, sendErr := s.SendObjCtxLogical(ctx, logical, StreamOpenSignalKey, req)
	if sendErr != nil {
		return none, sendErr
	}
	return decodeStreamOpenResponse(reply)
}
// sendStreamOpenServerTransport sends req under StreamOpenSignalKey over the
// given transport connection and decodes the peer's reply. A nil server
// yields errStreamServerNil and a nil transport errStreamTransportNil.
func sendStreamOpenServerTransport(ctx context.Context, s Server, transport *TransportConn, req StreamOpenRequest) (StreamOpenResponse, error) {
	var none StreamOpenResponse
	switch {
	case s == nil:
		return none, errStreamServerNil
	case transport == nil:
		return none, errStreamTransportNil
	}
	reply, sendErr := s.SendObjCtxTransport(ctx, transport, StreamOpenSignalKey, req)
	if sendErr != nil {
		return none, sendErr
	}
	return decodeStreamOpenResponse(reply)
}
// sendStreamCloseClient sends req under StreamCloseSignalKey through the
// client and decodes the peer's reply. It returns errStreamClientNil for a
// nil client and the send error if sending fails.
func sendStreamCloseClient(ctx context.Context, c Client, req StreamCloseRequest) (StreamCloseResponse, error) {
	var none StreamCloseResponse
	if c == nil {
		return none, errStreamClientNil
	}
	reply, sendErr := c.SendObjCtx(ctx, StreamCloseSignalKey, req)
	if sendErr != nil {
		return none, sendErr
	}
	return decodeStreamCloseResponse(reply)
}
// sendStreamCloseServerLogical sends req under StreamCloseSignalKey to the
// given logical connection and decodes the peer's reply. A nil server yields
// errStreamServerNil and a nil logical connection errStreamLogicalConnNil.
func sendStreamCloseServerLogical(ctx context.Context, s Server, logical *LogicalConn, req StreamCloseRequest) (StreamCloseResponse, error) {
	var none StreamCloseResponse
	switch {
	case s == nil:
		return none, errStreamServerNil
	case logical == nil:
		return none, errStreamLogicalConnNil
	}
	reply, sendErr := s.SendObjCtxLogical(ctx, logical, StreamCloseSignalKey, req)
	if sendErr != nil {
		return none, sendErr
	}
	return decodeStreamCloseResponse(reply)
}
// sendStreamCloseServerTransport sends req under StreamCloseSignalKey over the
// given transport connection and decodes the peer's reply. A nil server
// yields errStreamServerNil and a nil transport errStreamTransportNil.
func sendStreamCloseServerTransport(ctx context.Context, s Server, transport *TransportConn, req StreamCloseRequest) (StreamCloseResponse, error) {
	var none StreamCloseResponse
	switch {
	case s == nil:
		return none, errStreamServerNil
	case transport == nil:
		return none, errStreamTransportNil
	}
	reply, sendErr := s.SendObjCtxTransport(ctx, transport, StreamCloseSignalKey, req)
	if sendErr != nil {
		return none, sendErr
	}
	return decodeStreamCloseResponse(reply)
}
// sendStreamResetClient sends req under StreamResetSignalKey through the
// client and decodes the peer's reply. It returns errStreamClientNil for a
// nil client and the send error if sending fails.
func sendStreamResetClient(ctx context.Context, c Client, req StreamResetRequest) (StreamResetResponse, error) {
	var none StreamResetResponse
	if c == nil {
		return none, errStreamClientNil
	}
	reply, sendErr := c.SendObjCtx(ctx, StreamResetSignalKey, req)
	if sendErr != nil {
		return none, sendErr
	}
	return decodeStreamResetResponse(reply)
}
// sendStreamResetServerLogical sends req under StreamResetSignalKey to the
// given logical connection and decodes the peer's reply. A nil server yields
// errStreamServerNil and a nil logical connection errStreamLogicalConnNil.
func sendStreamResetServerLogical(ctx context.Context, s Server, logical *LogicalConn, req StreamResetRequest) (StreamResetResponse, error) {
	var none StreamResetResponse
	switch {
	case s == nil:
		return none, errStreamServerNil
	case logical == nil:
		return none, errStreamLogicalConnNil
	}
	reply, sendErr := s.SendObjCtxLogical(ctx, logical, StreamResetSignalKey, req)
	if sendErr != nil {
		return none, sendErr
	}
	return decodeStreamResetResponse(reply)
}
// sendStreamResetServerTransport sends req under StreamResetSignalKey over the
// given transport connection and decodes the peer's reply. A nil server
// yields errStreamServerNil and a nil transport errStreamTransportNil.
func sendStreamResetServerTransport(ctx context.Context, s Server, transport *TransportConn, req StreamResetRequest) (StreamResetResponse, error) {
	var none StreamResetResponse
	switch {
	case s == nil:
		return none, errStreamServerNil
	case transport == nil:
		return none, errStreamTransportNil
	}
	reply, sendErr := s.SendObjCtxTransport(ctx, transport, StreamResetSignalKey, req)
	if sendErr != nil {
		return none, sendErr
	}
	return decodeStreamResetResponse(reply)
}
// decodeStreamOpenRequest extracts a StreamOpenRequest from msg and passes it
// through normalizeStreamOpenRequest. It returns errStreamIDEmpty for a nil
// message or when the normalized request has no StreamID, and the decoder's
// own error when the payload cannot be mapped. On any error the returned
// request is the zero value.
func decodeStreamOpenRequest(msg *Message) (StreamOpenRequest, error) {
	if msg == nil {
		return StreamOpenRequest{}, errStreamIDEmpty
	}
	var decoded StreamOpenRequest
	if ormErr := msg.Value.Orm(&decoded); ormErr != nil {
		return StreamOpenRequest{}, ormErr
	}
	// Validate only after normalization has had its say.
	if decoded = normalizeStreamOpenRequest(decoded); decoded.StreamID == "" {
		return StreamOpenRequest{}, errStreamIDEmpty
	}
	return decoded, nil
}
// decodeStreamCloseRequest extracts a StreamCloseRequest from msg. It returns
// errStreamIDEmpty for a nil message or a request without StreamID, and the
// decoder's own error when the payload cannot be mapped. On any error the
// returned request is the zero value.
func decodeStreamCloseRequest(msg *Message) (StreamCloseRequest, error) {
	if msg == nil {
		return StreamCloseRequest{}, errStreamIDEmpty
	}
	var decoded StreamCloseRequest
	ormErr := msg.Value.Orm(&decoded)
	switch {
	case ormErr != nil:
		return StreamCloseRequest{}, ormErr
	case decoded.StreamID == "":
		return StreamCloseRequest{}, errStreamIDEmpty
	}
	return decoded, nil
}
// decodeStreamResetRequest extracts a StreamResetRequest from msg.
//
// A reset may identify its target by StreamID, by DataID, or both: the reset
// handlers fall back to a DataID lookup and fill in the response's StreamID
// when the request carried none. The request is therefore rejected with
// errStreamIDEmpty only when it names no stream at all (empty StreamID and
// zero DataID) or when msg is nil. A payload that cannot be mapped yields the
// decoder's own error. On any error the returned request is the zero value.
func decodeStreamResetRequest(msg *Message) (StreamResetRequest, error) {
	if msg == nil {
		return StreamResetRequest{}, errStreamIDEmpty
	}
	var req StreamResetRequest
	if err := msg.Value.Orm(&req); err != nil {
		return StreamResetRequest{}, err
	}
	// Insisting on StreamID here would make the handlers' DataID-only path
	// unreachable.
	if req.StreamID == "" && req.DataID == 0 {
		return StreamResetRequest{}, errStreamIDEmpty
	}
	return req, nil
}
// decodeStreamOpenResponse extracts a StreamOpenResponse from msg. A payload
// that cannot be mapped yields a zero response and the decoder's error.
// Otherwise the decoded response is returned together with the verdict of
// streamControlResultError for operation "open", so a rejected open still
// carries its response fields.
func decodeStreamOpenResponse(msg Message) (StreamOpenResponse, error) {
	var decoded StreamOpenResponse
	if ormErr := msg.Value.Orm(&decoded); ormErr != nil {
		return StreamOpenResponse{}, ormErr
	}
	verdict := streamControlResultError("open", decoded.Accepted, decoded.Error, nil)
	return decoded, verdict
}
// decodeStreamCloseResponse extracts a StreamCloseResponse from msg. A payload
// that cannot be mapped yields a zero response and the decoder's error.
// Otherwise the decoded response is returned together with the verdict of
// streamControlResultError for operation "close".
func decodeStreamCloseResponse(msg Message) (StreamCloseResponse, error) {
	var decoded StreamCloseResponse
	if ormErr := msg.Value.Orm(&decoded); ormErr != nil {
		return StreamCloseResponse{}, ormErr
	}
	verdict := streamControlResultError("close", decoded.Accepted, decoded.Error, nil)
	return decoded, verdict
}
// decodeStreamResetResponse extracts a StreamResetResponse from msg. A payload
// that cannot be mapped yields a zero response and the decoder's error.
// Otherwise the decoded response is returned together with the verdict of
// streamControlResultError for operation "reset".
func decodeStreamResetResponse(msg Message) (StreamResetResponse, error) {
	var decoded StreamResetResponse
	if ormErr := msg.Value.Orm(&decoded); ormErr != nil {
		return StreamResetResponse{}, ormErr
	}
	verdict := streamControlResultError("reset", decoded.Accepted, decoded.Error, nil)
	return decoded, verdict
}
// streamControlResultError turns the outcome of a stream control exchange
// into an error. The first rule that applies wins:
//
//  1. callErr, when non-nil, is returned unchanged;
//  2. a non-empty message is mapped by streamControlMessageError, even when
//     accepted is true;
//  3. accepted yields nil;
//  4. a rejected "open" yields errStreamRejected;
//  5. any other rejected op yields a new error "stream <op> rejected".
func streamControlResultError(op string, accepted bool, message string, callErr error) error {
	switch {
	case callErr != nil:
		return callErr
	case message != "":
		return streamControlMessageError(message)
	case accepted:
		return nil
	case op == "open":
		return errStreamRejected
	default:
		return errors.New("stream " + op + " rejected")
	}
}
// streamControlMessageError converts an error text received from the peer
// back into an error value. When the text equals the message of one of the
// known stream sentinels, that sentinel itself is returned. Any other text is
// wrapped in a fresh error carrying the text verbatim.
func streamControlMessageError(message string) error {
	// Checked in order; the first sentinel whose text matches wins.
	sentinels := [...]error{
		errStreamNotFound,
		errStreamAlreadyExists,
		errStreamHandlerNotConfigured,
		errStreamLogicalConnNil,
		errStreamTransportNil,
		errStreamRuntimeNil,
		errStreamIDEmpty,
	}
	for _, known := range sentinels {
		if known.Error() == message {
			return known
		}
	}
	return errors.New(message)
}
// streamTransportGeneration picks the transport generation to stamp on a
// stream: the transport connection's own generation when one is given,
// otherwise the logical connection's snapshot, otherwise zero.
func streamTransportGeneration(logical *LogicalConn, transport *TransportConn) uint64 {
	switch {
	case transport != nil:
		return transport.TransportGeneration()
	case logical != nil:
		return logical.transportGenerationSnapshot()
	}
	return 0
}
// streamRemoteResetError builds the error for a reset requested by the peer:
// a new error carrying the peer's reason text when one was given, otherwise
// errStreamReset.
func streamRemoteResetError(message string) error {
	if message != "" {
		return errors.New(message)
	}
	return errStreamReset
}