package notify

import (
	"context"
	"errors"
	"time"
)

// StreamOpenRequest is the control payload sent with StreamOpenSignalKey to
// ask the peer to open a stream. A zero DataID asks the receiving side to
// allocate one from its stream runtime.
type StreamOpenRequest struct {
	StreamID     string
	DataID       uint64
	Channel      StreamChannel
	Metadata     StreamMetadata
	ReadTimeout  time.Duration
	WriteTimeout time.Duration
}

// StreamOpenResponse is the reply to a StreamOpenRequest. DataID reports the
// data ID the receiver bound to the stream (possibly freshly allocated) and
// Error is non-empty when the open was rejected.
type StreamOpenResponse struct {
	StreamID            string
	DataID              uint64
	Accepted            bool
	TransportGeneration uint64
	Error               string
}

// StreamCloseRequest asks the peer to close a stream. When Full is set the
// receiver calls markPeerClosed on the stream, otherwise markRemoteClosed
// (presumably a half close — confirm against the stream handle).
type StreamCloseRequest struct {
	StreamID string
	Full     bool
}

// StreamCloseResponse is the reply to a StreamCloseRequest.
type StreamCloseResponse struct {
	StreamID string
	Accepted bool
	Error    string
}

// StreamResetRequest asks the peer to reset a stream. The receiver looks the
// stream up by StreamID and, failing that, by a non-zero DataID, so at least
// one of the two must be set. Error carries the reset reason; an empty Error
// resets the stream with errStreamReset.
type StreamResetRequest struct {
	StreamID string
	DataID   uint64
	Error    string
}

// StreamResetResponse is the reply to a StreamResetRequest. When the request
// identified the stream only by DataID, StreamID is filled in from the
// matched stream.
type StreamResetResponse struct {
	StreamID string
	Accepted bool
	Error    string
}

// bindClientStreamControl registers the client's inbound handlers for the
// stream open, close and reset control signals. A nil client is ignored.
func bindClientStreamControl(c *ClientCommon) {
	if c == nil {
		return
	}
	c.SetLink(StreamOpenSignalKey, c.handleInboundStreamOpen)
	c.SetLink(StreamCloseSignalKey, c.handleInboundStreamClose)
	c.SetLink(StreamResetSignalKey, c.handleInboundStreamReset)
}

// bindServerStreamControl registers the server's inbound handlers for the
// stream open, close and reset control signals. A nil server is ignored.
func bindServerStreamControl(s *ServerCommon) {
	if s == nil {
		return
	}
	s.SetLink(StreamOpenSignalKey, s.handleInboundStreamOpen)
	s.SetLink(StreamCloseSignalKey, s.handleInboundStreamClose)
	s.SetLink(StreamResetSignalKey, s.handleInboundStreamReset)
}

// handleInboundStreamOpen serves a peer's StreamOpenRequest on the client.
//
// It decodes the request, allocates a DataID when the peer sent zero, builds
// and registers a stream handle in the client scope, then offers the stream
// to the record claimer, the transfer claimer and finally the configured
// accept handler, in that order. Once a stream is registered, any rejection
// also resets it. The outcome is sent back as a StreamOpenResponse when the
// sender waits for a reply.
func (c *ClientCommon) handleInboundStreamOpen(msg *Message) {
	req, err := decodeStreamOpenRequest(msg)
	resp := StreamOpenResponse{StreamID: req.StreamID, DataID: req.DataID}
	// fail replies with an error without touching any stream.
	fail := func(cause error) {
		resp.Error = cause.Error()
		replyStreamControlIfNeeded(msg, resp)
	}
	if err != nil {
		fail(err)
		return
	}
	runtime := c.getStreamRuntime()
	if runtime == nil {
		fail(errStreamRuntimeNil)
		return
	}
	scope := clientFileScope()
	if req.DataID == 0 {
		req.DataID = runtime.nextDataID()
		resp.DataID = req.DataID
	}
	// Read the session epoch once so the stream handle and its data sender
	// are bound to the same epoch even if it changes concurrently.
	epoch := c.currentClientSessionEpoch()
	stream := newStreamHandle(c.clientStopContextSnapshot(), runtime, scope, req, epoch,
		nil, nil, 0,
		clientStreamCloseSender(c),
		clientStreamResetSender(c),
		clientStreamDataSender(c, epoch),
		runtime.configSnapshot())
	stream.setClientSnapshotOwner(c)
	stream.setAddrSnapshot(c.clientStreamAddrSnapshot())
	if err := runtime.register(scope, stream); err != nil {
		fail(err)
		return
	}
	// reject resets the registered stream and reports the cause to the peer.
	reject := func(cause error) {
		stream.markReset(cause)
		fail(cause)
	}
	// accept reports the stream's actual data ID and transport generation on
	// every accepted path, whether a claimer or the handler took the stream.
	accept := func() {
		resp.Accepted = true
		resp.DataID = stream.dataIDSnapshot()
		resp.TransportGeneration = stream.TransportGeneration()
		replyStreamControlIfNeeded(msg, resp)
	}
	if claimed, err := c.claimInboundRecordStream(stream); claimed {
		if err != nil {
			reject(err)
		} else {
			accept()
		}
		return
	}
	if claimed, err := c.claimInboundTransferStream(stream); claimed {
		if err != nil {
			reject(err)
		} else {
			accept()
		}
		return
	}
	handler := runtime.handlerSnapshot()
	if handler == nil {
		reject(errStreamHandlerNotConfigured)
		return
	}
	info := StreamAcceptInfo{
		ID:                  stream.ID(),
		DataID:              stream.dataIDSnapshot(),
		Channel:             stream.Channel(),
		Metadata:            stream.Metadata(),
		TransportGeneration: stream.TransportGeneration(),
		Stream:              stream,
	}
	if err := handler(info); err != nil {
		reject(err)
		return
	}
	accept()
}

// handleInboundStreamOpen serves a peer's StreamOpenRequest on the server.
//
// It behaves like the client variant, except that the stream is scoped to
// the logical connection the message arrived on. A message without a logical
// connection is rejected with errStreamLogicalConnNil. The stream handle is
// bound to the message's transport connection (which may be nil) and its
// transport generation.
func (s *ServerCommon) handleInboundStreamOpen(msg *Message) {
	req, err := decodeStreamOpenRequest(msg)
	resp := StreamOpenResponse{StreamID: req.StreamID, DataID: req.DataID}
	// fail replies with an error without touching any stream.
	fail := func(cause error) {
		resp.Error = cause.Error()
		replyStreamControlIfNeeded(msg, resp)
	}
	if err != nil {
		fail(err)
		return
	}
	runtime := s.getStreamRuntime()
	if runtime == nil {
		fail(errStreamRuntimeNil)
		return
	}
	logical := messageLogicalConnSnapshot(msg)
	if logical == nil {
		fail(errStreamLogicalConnNil)
		return
	}
	transport := messageTransportConnSnapshot(msg)
	scope := serverFileScope(logical)
	if req.DataID == 0 {
		req.DataID = runtime.nextDataID()
		resp.DataID = req.DataID
	}
	stream := newStreamHandle(logical.stopContextSnapshot(), runtime, scope, req, 0,
		logical, transport, streamTransportGeneration(logical, transport),
		serverStreamCloseSender(s, logical, transport),
		serverStreamResetSender(s, logical, transport),
		serverStreamDataSender(s, transport),
		runtime.configSnapshot())
	if err := runtime.register(scope, stream); err != nil {
		fail(err)
		return
	}
	// reject resets the registered stream and reports the cause to the peer.
	reject := func(cause error) {
		stream.markReset(cause)
		fail(cause)
	}
	// accept reports the stream's actual data ID and transport generation on
	// every accepted path, whether a claimer or the handler took the stream.
	accept := func() {
		resp.Accepted = true
		resp.DataID = stream.dataIDSnapshot()
		resp.TransportGeneration = stream.TransportGeneration()
		replyStreamControlIfNeeded(msg, resp)
	}
	if claimed, err := s.claimInboundRecordStream(logical, transport, stream); claimed {
		if err != nil {
			reject(err)
		} else {
			accept()
		}
		return
	}
	if claimed, err := s.claimInboundTransferStream(logical, transport, stream); claimed {
		if err != nil {
			reject(err)
		} else {
			accept()
		}
		return
	}
	handler := runtime.handlerSnapshot()
	if handler == nil {
		reject(errStreamHandlerNotConfigured)
		return
	}
	info := StreamAcceptInfo{
		ID:                  stream.ID(),
		DataID:              stream.dataIDSnapshot(),
		Channel:             stream.Channel(),
		Metadata:            stream.Metadata(),
		LogicalConn:         logical,
		TransportConn:       transport,
		TransportGeneration: stream.TransportGeneration(),
		Stream:              stream,
	}
	if err := handler(info); err != nil {
		reject(err)
		return
	}
	accept()
}

// handleInboundStreamClose serves a peer's StreamCloseRequest on the client.
// It looks the stream up in the client scope and marks it peer-closed (Full)
// or remote-closed, replying with errStreamNotFound when no stream matches.
func (c *ClientCommon) handleInboundStreamClose(msg *Message) {
	req, err := decodeStreamCloseRequest(msg)
	resp := StreamCloseResponse{StreamID: req.StreamID}
	if err != nil {
		resp.Error = err.Error()
		replyStreamControlIfNeeded(msg, resp)
		return
	}
	runtime := c.getStreamRuntime()
	if runtime == nil {
		resp.Error = errStreamRuntimeNil.Error()
		replyStreamControlIfNeeded(msg, resp)
		return
	}
	stream, ok := runtime.lookup(clientFileScope(), req.StreamID)
	if !ok {
		resp.Error = errStreamNotFound.Error()
		replyStreamControlIfNeeded(msg, resp)
		return
	}
	if req.Full {
		stream.markPeerClosed()
	} else {
		stream.markRemoteClosed()
	}
	resp.Accepted = true
	replyStreamControlIfNeeded(msg, resp)
}

// handleInboundStreamClose serves a peer's StreamCloseRequest on the server.
// The stream is looked up in the scope of the message's logical connection.
// Like the open handler, a message without one is rejected with
// errStreamLogicalConnNil rather than being resolved against a nil-conn scope.
func (s *ServerCommon) handleInboundStreamClose(msg *Message) {
	req, err := decodeStreamCloseRequest(msg)
	resp := StreamCloseResponse{StreamID: req.StreamID}
	if err != nil {
		resp.Error = err.Error()
		replyStreamControlIfNeeded(msg, resp)
		return
	}
	runtime := s.getStreamRuntime()
	if runtime == nil {
		resp.Error = errStreamRuntimeNil.Error()
		replyStreamControlIfNeeded(msg, resp)
		return
	}
	logical := messageLogicalConnSnapshot(msg)
	if logical == nil {
		resp.Error = errStreamLogicalConnNil.Error()
		replyStreamControlIfNeeded(msg, resp)
		return
	}
	stream, ok := runtime.lookup(serverFileScope(logical), req.StreamID)
	if !ok {
		resp.Error = errStreamNotFound.Error()
		replyStreamControlIfNeeded(msg, resp)
		return
	}
	if req.Full {
		stream.markPeerClosed()
	} else {
		stream.markRemoteClosed()
	}
	resp.Accepted = true
	replyStreamControlIfNeeded(msg, resp)
}

// handleInboundStreamReset serves a peer's StreamResetRequest on the client.
// The stream is looked up by StreamID and, failing that, by a non-zero
// DataID. It is then reset with the peer's reason (errStreamReset when the
// peer sent none). The response StreamID is backfilled from the matched
// stream for DataID-only requests.
func (c *ClientCommon) handleInboundStreamReset(msg *Message) {
	req, err := decodeStreamResetRequest(msg)
	resp := StreamResetResponse{StreamID: req.StreamID}
	if err != nil {
		resp.Error = err.Error()
		replyStreamControlIfNeeded(msg, resp)
		return
	}
	runtime := c.getStreamRuntime()
	if runtime == nil {
		resp.Error = errStreamRuntimeNil.Error()
		replyStreamControlIfNeeded(msg, resp)
		return
	}
	scope := clientFileScope()
	stream, ok := runtime.lookup(scope, req.StreamID)
	if !ok && req.DataID != 0 {
		stream, ok = runtime.lookupByDataID(scope, req.DataID)
	}
	if !ok {
		resp.Error = errStreamNotFound.Error()
		replyStreamControlIfNeeded(msg, resp)
		return
	}
	if resp.StreamID == "" {
		resp.StreamID = stream.ID()
	}
	stream.markReset(streamResetError(streamRemoteResetError(req.Error)))
	resp.Accepted = true
	replyStreamControlIfNeeded(msg, resp)
}

// handleInboundStreamReset serves a peer's StreamResetRequest on the server.
// It resolves the stream exactly like the client variant, but within the
// scope of the message's logical connection. A message without one is
// rejected with errStreamLogicalConnNil, matching the open handler.
func (s *ServerCommon) handleInboundStreamReset(msg *Message) {
	req, err := decodeStreamResetRequest(msg)
	resp := StreamResetResponse{StreamID: req.StreamID}
	if err != nil {
		resp.Error = err.Error()
		replyStreamControlIfNeeded(msg, resp)
		return
	}
	runtime := s.getStreamRuntime()
	if runtime == nil {
		resp.Error = errStreamRuntimeNil.Error()
		replyStreamControlIfNeeded(msg, resp)
		return
	}
	logical := messageLogicalConnSnapshot(msg)
	if logical == nil {
		resp.Error = errStreamLogicalConnNil.Error()
		replyStreamControlIfNeeded(msg, resp)
		return
	}
	scope := serverFileScope(logical)
	stream, ok := runtime.lookup(scope, req.StreamID)
	if !ok && req.DataID != 0 {
		stream, ok = runtime.lookupByDataID(scope, req.DataID)
	}
	if !ok {
		resp.Error = errStreamNotFound.Error()
		replyStreamControlIfNeeded(msg, resp)
		return
	}
	if resp.StreamID == "" {
		resp.StreamID = stream.ID()
	}
	stream.markReset(streamResetError(streamRemoteResetError(req.Error)))
	resp.Accepted = true
	replyStreamControlIfNeeded(msg, resp)
}

// replyStreamControlIfNeeded sends value as the reply to msg, but only when
// the sender is waiting for one (requiresSignalReplyWait). A nil msg is
// ignored. A failed reply is deliberately dropped: the inbound handlers that
// call this have no caller to report it to.
func replyStreamControlIfNeeded(msg *Message, value interface{}) {
	if msg == nil || !requiresSignalReplyWait(msg.TransferMsg) {
		return
	}
	_ = msg.ReplyObj(value)
}

// sendStreamOpenClient sends req to the server over client c and decodes the
// reply. A rejected open is returned as an error alongside the response.
func sendStreamOpenClient(ctx context.Context, c Client, req StreamOpenRequest) (StreamOpenResponse, error) {
	if c == nil {
		return StreamOpenResponse{}, errStreamClientNil
	}
	msg, err := c.SendObjCtx(ctx, StreamOpenSignalKey, req)
	if err != nil {
		return StreamOpenResponse{}, err
	}
	return decodeStreamOpenResponse(msg)
}

// sendStreamOpenServerLogical sends req to the peer behind logical connection
// logical and decodes the reply.
func sendStreamOpenServerLogical(ctx context.Context, s Server, logical *LogicalConn, req StreamOpenRequest) (StreamOpenResponse, error) {
	if s == nil {
		return StreamOpenResponse{}, errStreamServerNil
	}
	if logical == nil {
		return StreamOpenResponse{}, errStreamLogicalConnNil
	}
	msg, err := s.SendObjCtxLogical(ctx, logical, StreamOpenSignalKey, req)
	if err != nil {
		return StreamOpenResponse{}, err
	}
	return decodeStreamOpenResponse(msg)
}

// sendStreamOpenServerTransport sends req over transport connection transport
// and decodes the reply.
func sendStreamOpenServerTransport(ctx context.Context, s Server, transport *TransportConn, req StreamOpenRequest) (StreamOpenResponse, error) {
	if s == nil {
		return StreamOpenResponse{}, errStreamServerNil
	}
	if transport == nil {
		return StreamOpenResponse{}, errStreamTransportNil
	}
	msg, err := s.SendObjCtxTransport(ctx, transport, StreamOpenSignalKey, req)
	if err != nil {
		return StreamOpenResponse{}, err
	}
	return decodeStreamOpenResponse(msg)
}

// sendStreamCloseClient sends a close request to the server over client c and
// decodes the reply.
func sendStreamCloseClient(ctx context.Context, c Client, req StreamCloseRequest) (StreamCloseResponse, error) {
	if c == nil {
		return StreamCloseResponse{}, errStreamClientNil
	}
	msg, err := c.SendObjCtx(ctx, StreamCloseSignalKey, req)
	if err != nil {
		return StreamCloseResponse{}, err
	}
	return decodeStreamCloseResponse(msg)
}

// sendStreamCloseServerLogical sends a close request to the peer behind
// logical connection logical and decodes the reply.
func sendStreamCloseServerLogical(ctx context.Context, s Server, logical *LogicalConn, req StreamCloseRequest) (StreamCloseResponse, error) {
	if s == nil {
		return StreamCloseResponse{}, errStreamServerNil
	}
	if logical == nil {
		return StreamCloseResponse{}, errStreamLogicalConnNil
	}
	msg, err := s.SendObjCtxLogical(ctx, logical, StreamCloseSignalKey, req)
	if err != nil {
		return StreamCloseResponse{}, err
	}
	return decodeStreamCloseResponse(msg)
}

// sendStreamCloseServerTransport sends a close request over transport
// connection transport and decodes the reply.
func sendStreamCloseServerTransport(ctx context.Context, s Server, transport *TransportConn, req StreamCloseRequest) (StreamCloseResponse, error) {
	if s == nil {
		return StreamCloseResponse{}, errStreamServerNil
	}
	if transport == nil {
		return StreamCloseResponse{}, errStreamTransportNil
	}
	msg, err := s.SendObjCtxTransport(ctx, transport, StreamCloseSignalKey, req)
	if err != nil {
		return StreamCloseResponse{}, err
	}
	return decodeStreamCloseResponse(msg)
}

// sendStreamResetClient sends a reset request to the server over client c and
// decodes the reply.
func sendStreamResetClient(ctx context.Context, c Client, req StreamResetRequest) (StreamResetResponse, error) {
	if c == nil {
		return StreamResetResponse{}, errStreamClientNil
	}
	msg, err := c.SendObjCtx(ctx, StreamResetSignalKey, req)
	if err != nil {
		return StreamResetResponse{}, err
	}
	return decodeStreamResetResponse(msg)
}

// sendStreamResetServerLogical sends a reset request to the peer behind
// logical connection logical and decodes the reply.
func sendStreamResetServerLogical(ctx context.Context, s Server, logical *LogicalConn, req StreamResetRequest) (StreamResetResponse, error) {
	if s == nil {
		return StreamResetResponse{}, errStreamServerNil
	}
	if logical == nil {
		return StreamResetResponse{}, errStreamLogicalConnNil
	}
	msg, err := s.SendObjCtxLogical(ctx, logical, StreamResetSignalKey, req)
	if err != nil {
		return StreamResetResponse{}, err
	}
	return decodeStreamResetResponse(msg)
}

// sendStreamResetServerTransport sends a reset request over transport
// connection transport and decodes the reply.
func sendStreamResetServerTransport(ctx context.Context, s Server, transport *TransportConn, req StreamResetRequest) (StreamResetResponse, error) {
	if s == nil {
		return StreamResetResponse{}, errStreamServerNil
	}
	if transport == nil {
		return StreamResetResponse{}, errStreamTransportNil
	}
	msg, err := s.SendObjCtxTransport(ctx, transport, StreamResetSignalKey, req)
	if err != nil {
		return StreamResetResponse{}, err
	}
	return decodeStreamResetResponse(msg)
}

// decodeStreamOpenRequest decodes and normalizes a StreamOpenRequest from msg.
// A nil msg, or a request whose StreamID is empty after normalization,
// yields errStreamIDEmpty.
func decodeStreamOpenRequest(msg *Message) (StreamOpenRequest, error) {
	if msg == nil {
		return StreamOpenRequest{}, errStreamIDEmpty
	}
	var req StreamOpenRequest
	if err := msg.Value.Orm(&req); err != nil {
		return StreamOpenRequest{}, err
	}
	req = normalizeStreamOpenRequest(req)
	if req.StreamID == "" {
		return StreamOpenRequest{}, errStreamIDEmpty
	}
	return req, nil
}

// decodeStreamCloseRequest decodes a StreamCloseRequest from msg. A nil msg
// or an empty StreamID yields errStreamIDEmpty.
func decodeStreamCloseRequest(msg *Message) (StreamCloseRequest, error) {
	if msg == nil {
		return StreamCloseRequest{}, errStreamIDEmpty
	}
	var req StreamCloseRequest
	if err := msg.Value.Orm(&req); err != nil {
		return StreamCloseRequest{}, err
	}
	if req.StreamID == "" {
		return StreamCloseRequest{}, errStreamIDEmpty
	}
	return req, nil
}

// decodeStreamResetRequest decodes a StreamResetRequest from msg. The request
// must identify its stream by StreamID or by a non-zero DataID; a nil msg or
// a request carrying neither yields errStreamIDEmpty.
func decodeStreamResetRequest(msg *Message) (StreamResetRequest, error) {
	if msg == nil {
		return StreamResetRequest{}, errStreamIDEmpty
	}
	var req StreamResetRequest
	if err := msg.Value.Orm(&req); err != nil {
		return StreamResetRequest{}, err
	}
	// The reset handlers fall back to a DataID lookup and backfill the
	// response StreamID, so a DataID-only request is legitimate.
	if req.StreamID == "" && req.DataID == 0 {
		return StreamResetRequest{}, errStreamIDEmpty
	}
	return req, nil
}

// decodeStreamOpenResponse decodes an open reply. The decoded response is
// returned together with the error produced by streamControlResultError.
func decodeStreamOpenResponse(msg Message) (StreamOpenResponse, error) {
	var resp StreamOpenResponse
	if err := msg.Value.Orm(&resp); err != nil {
		return StreamOpenResponse{}, err
	}
	return resp, streamControlResultError("open", resp.Accepted, resp.Error, nil)
}

// decodeStreamCloseResponse decodes a close reply. The decoded response is
// returned together with the error produced by streamControlResultError.
func decodeStreamCloseResponse(msg Message) (StreamCloseResponse, error) {
	var resp StreamCloseResponse
	if err := msg.Value.Orm(&resp); err != nil {
		return StreamCloseResponse{}, err
	}
	return resp, streamControlResultError("close", resp.Accepted, resp.Error, nil)
}

// decodeStreamResetResponse decodes a reset reply. The decoded response is
// returned together with the error produced by streamControlResultError.
func decodeStreamResetResponse(msg Message) (StreamResetResponse, error) {
	var resp StreamResetResponse
	if err := msg.Value.Orm(&resp); err != nil {
		return StreamResetResponse{}, err
	}
	return resp, streamControlResultError("reset", resp.Accepted, resp.Error, nil)
}

// streamControlResultError converts the outcome of a control exchange into an
// error. Checks run in this order:
//   - callErr wins.
//   - A non-empty peer message is mapped back to a known sentinel where
//     possible, even if the peer also set Accepted.
//   - Accepted with no message is success.
//   - Anything else is a rejection: errStreamRejected for "open", a generic
//     error for other ops.
func streamControlResultError(op string, accepted bool, message string, callErr error) error {
	switch {
	case callErr != nil:
		return callErr
	case message != "":
		return streamControlMessageError(message)
	case accepted:
		return nil
	case op == "open":
		return errStreamRejected
	default:
		return errors.New("stream " + op + " rejected")
	}
}

// streamControlMessageError maps an error string received from the peer back
// to the matching local sentinel error, so callers can compare with
// errors.Is. Unrecognised messages become plain errors.
func streamControlMessageError(message string) error {
	switch message {
	case errStreamNotFound.Error():
		return errStreamNotFound
	case errStreamAlreadyExists.Error():
		return errStreamAlreadyExists
	case errStreamHandlerNotConfigured.Error():
		return errStreamHandlerNotConfigured
	case errStreamLogicalConnNil.Error():
		return errStreamLogicalConnNil
	case errStreamTransportNil.Error():
		return errStreamTransportNil
	case errStreamRuntimeNil.Error():
		return errStreamRuntimeNil
	case errStreamIDEmpty.Error():
		return errStreamIDEmpty
	default:
		return errors.New(message)
	}
}

// streamTransportGeneration returns the transport's generation when a
// transport is known, otherwise the logical connection's generation snapshot,
// and 0 when neither is available.
func streamTransportGeneration(logical *LogicalConn, transport *TransportConn) uint64 {
	if transport != nil {
		return transport.TransportGeneration()
	}
	if logical != nil {
		return logical.transportGenerationSnapshot()
	}
	return 0
}

// streamRemoteResetError turns a peer-supplied reset reason into an error.
// An empty reason maps to errStreamReset.
func streamRemoteResetError(message string) error {
	if message == "" {
		return errStreamReset
	}
	return errors.New(message)
}