package notify

import (
	"context"
	"errors"
	"io"
	"math"
	"net"
	"os"
	"strings"
	"testing"
	"time"

	"b612.me/stario"
)

// TestStreamOpenRoundTripTCP opens a stream from the client over TCP and
// checks that the server's StreamAcceptInfo mirrors the open request (ID,
// channel, metadata, data ID, non-nil connections), that payloads flow in both
// directions, and that Close on each side ends the stream.
func TestStreamOpenRoundTripTCP(t *testing.T) {
	server := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	acceptCh := make(chan StreamAcceptInfo, 1)
	server.SetStreamHandler(func(info StreamAcceptInfo) error {
		acceptCh <- info
		return nil
	})
	if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	defer func() { _ = server.Stop() }()

	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	defer func() { _ = client.Stop() }()

	stream, err := client.OpenStream(context.Background(), StreamOpenOptions{
		Channel: StreamDataChannel,
		Metadata: StreamMetadata{
			"name": "demo.bin",
		},
	})
	if err != nil {
		t.Fatalf("client OpenStream failed: %v", err)
	}

	accepted := waitAcceptedStream(t, acceptCh, 2*time.Second)
	if accepted.ID != stream.ID() {
		t.Fatalf("accepted stream id mismatch: got %q want %q", accepted.ID, stream.ID())
	}
	if accepted.Channel != StreamDataChannel {
		t.Fatalf("accepted stream channel mismatch: got %q want %q", accepted.Channel, StreamDataChannel)
	}
	if accepted.Metadata["name"] != "demo.bin" {
		t.Fatalf("accepted metadata mismatch: %+v", accepted.Metadata)
	}
	if accepted.LogicalConn == nil {
		t.Fatal("accepted logical connection should not be nil")
	}
	if accepted.TransportConn == nil {
		t.Fatal("accepted transport connection should not be nil")
	}
	clientHandle, ok := stream.(*streamHandle)
	if !ok {
		t.Fatalf("stream type = %T, want *streamHandle", stream)
	}
	if accepted.DataID == 0 {
		t.Fatal("accepted stream data id should not be zero")
	}
	if got, want := clientHandle.dataIDSnapshot(), accepted.DataID; got != want {
		t.Fatalf("client stream data id = %d, want %d", got, want)
	}

	if _, err := stream.Write([]byte("hello-from-client")); err != nil {
		t.Fatalf("client stream Write failed: %v", err)
	}
	readStreamExactly(t, accepted.Stream, "hello-from-client", 2*time.Second)
	if _, err := accepted.Stream.Write([]byte("hello-from-server")); err != nil {
		t.Fatalf("server accepted stream Write failed: %v", err)
	}
	readStreamExactly(t, stream, "hello-from-server", 2*time.Second)

	if err := stream.Close(); err != nil {
		t.Fatalf("client stream Close failed: %v", err)
	}
	waitForStreamReadEOF(t, accepted.Stream, 2*time.Second)
	if err := accepted.Stream.Close(); err != nil {
		t.Fatalf("server accepted stream Close failed: %v", err)
	}
	waitForStreamContextDone(t, stream.Context(), 2*time.Second)
}

// TestStreamCloseWriteKeepsReadSideAliveTCP checks half-close semantics:
// after the server calls CloseWrite, the client sees EOF but can still write
// to the server. Once the client also calls CloseWrite, both directions are
// closed and the client stream's context is done.
func TestStreamCloseWriteKeepsReadSideAliveTCP(t *testing.T) {
	server := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	acceptCh := make(chan StreamAcceptInfo, 1)
	server.SetStreamHandler(func(info StreamAcceptInfo) error {
		acceptCh <- info
		return nil
	})
	if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	defer func() { _ = server.Stop() }()

	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	defer func() { _ = client.Stop() }()

	stream, err := client.OpenStream(context.Background(), StreamOpenOptions{Channel: StreamDataChannel})
	if err != nil {
		t.Fatalf("client OpenStream failed: %v", err)
	}
	accepted := waitAcceptedStream(t, acceptCh, 2*time.Second)

	if err := accepted.Stream.CloseWrite(); err != nil {
		t.Fatalf("server accepted stream CloseWrite failed: %v", err)
	}
	waitForStreamReadEOF(t, stream, 2*time.Second)

	if _, err := stream.Write([]byte("client-after-peer-close")); err != nil {
		t.Fatalf("client stream Write after peer CloseWrite failed: %v", err)
	}
	readStreamExactly(t, accepted.Stream, "client-after-peer-close", 2*time.Second)

	if err := stream.CloseWrite(); err != nil {
		t.Fatalf("client stream CloseWrite failed: %v", err)
	}
	waitForStreamReadEOF(t, accepted.Stream, 2*time.Second)
	waitForStreamContextDone(t, stream.Context(), 2*time.Second)
}

// TestStreamCloseFullStopsPeerWritesTCP checks that a full Close on the server
// side delivers EOF to the client, ends the client stream's context, and makes
// later client writes fail with io.ErrClosedPipe.
func TestStreamCloseFullStopsPeerWritesTCP(t *testing.T) {
	server := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	acceptCh := make(chan StreamAcceptInfo, 1)
	server.SetStreamHandler(func(info StreamAcceptInfo) error {
		acceptCh <- info
		return nil
	})
	if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	defer func() { _ = server.Stop() }()

	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	defer func() { _ = client.Stop() }()

	stream, err := client.OpenStream(context.Background(), StreamOpenOptions{Channel: StreamDataChannel})
	if err != nil {
		t.Fatalf("client OpenStream failed: %v", err)
	}
	accepted := waitAcceptedStream(t, acceptCh, 2*time.Second)

	if err := accepted.Stream.Close(); err != nil {
		t.Fatalf("server accepted stream Close failed: %v", err)
	}
	waitForStreamReadEOF(t, stream, 2*time.Second)
	waitForStreamContextDone(t, stream.Context(), 2*time.Second)

	if _, err := stream.Write([]byte("client-after-peer-full-close")); !errors.Is(err, io.ErrClosedPipe) {
		t.Fatalf("client stream Write after peer Close = %v, want %v", err, io.ErrClosedPipe)
	}
}

// TestStreamCloseAfterCloseWriteStopsPeerWritesTCP checks the path where the
// client first half-closes the stream, which leaves the server still able to
// reply, and then fully closes it. After that, the server's stream context is
// done and server writes fail with io.ErrClosedPipe.
func TestStreamCloseAfterCloseWriteStopsPeerWritesTCP(t *testing.T) {
	server := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	acceptCh := make(chan StreamAcceptInfo, 1)
	server.SetStreamHandler(func(info StreamAcceptInfo) error {
		acceptCh <- info
		return nil
	})
	if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	defer func() { _ = server.Stop() }()

	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	defer func() { _ = client.Stop() }()

	stream, err := client.OpenStream(context.Background(), StreamOpenOptions{Channel: StreamDataChannel})
	if err != nil {
		t.Fatalf("client OpenStream failed: %v", err)
	}
	accepted := waitAcceptedStream(t, acceptCh, 2*time.Second)

	if err := stream.CloseWrite(); err != nil {
		t.Fatalf("client stream CloseWrite failed: %v", err)
	}
	waitForStreamReadEOF(t, accepted.Stream, 2*time.Second)

	if _, err := accepted.Stream.Write([]byte("server-can-still-reply")); err != nil {
		t.Fatalf("server accepted stream Write after peer CloseWrite failed: %v", err)
	}
	readStreamExactly(t, stream, "server-can-still-reply", 2*time.Second)

	if err := stream.Close(); err != nil {
		t.Fatalf("client stream Close after CloseWrite failed: %v", err)
	}
	waitForStreamReadEOF(t, accepted.Stream, 2*time.Second)
	waitForStreamContextDone(t, accepted.Stream.Context(), 2*time.Second)

	if _, err := accepted.Stream.Write([]byte("server-after-peer-full-close")); !errors.Is(err, io.ErrClosedPipe) {
		t.Fatalf("server accepted stream Write after peer Close = %v, want %v", err, io.ErrClosedPipe)
	}
}

// TestStreamWritePrefersResetErrorOverContextCanceled checks that when the
// send callback marks the stream reset and then returns the context error,
// Write surfaces the reset cause rather than the context error.
func TestStreamWritePrefersResetErrorOverContextCanceled(t *testing.T) {
	wantErr := errors.New("remote stream reset")
	runtime := newStreamRuntime("stream-reset")
	stream := newStreamHandle(context.Background(), runtime, "test", StreamOpenRequest{
		StreamID: "stream-reset-propagation",
		DataID:   1,
	}, 0, nil, nil, 0, nil, nil, func(ctx context.Context, s *streamHandle, chunk []byte) error {
		s.markReset(wantErr)
		<-ctx.Done()
		return ctx.Err()
	}, streamConfig{ChunkSize: 4})

	_, err := stream.Write([]byte("abcdefgh"))
	if !errors.Is(err, wantErr) {
		t.Fatalf("stream Write error = %v, want %v", err, wantErr)
	}
}

// TestStreamWriteWaitingBudgetPrefersClosedPipeOverContextCanceled exhausts
// the runtime's outbound budget up front so Write blocks waiting for budget,
// then marks the peer closed. It checks that Write returns io.ErrClosedPipe.
func TestStreamWriteWaitingBudgetPrefersClosedPipeOverContextCanceled(t *testing.T) {
	cfg := streamConfig{
		ChunkSize:                 4,
		OutboundWindowBytes:       4,
		OutboundMaxInFlightChunks: 1,
	}
	runtime := newStreamRuntime("stream-budget-close")
	runtime.applyConfig(cfg)
	// Hold the whole outbound window so the stream's Write has to wait.
	release, err := runtime.acquireOutbound(context.Background(), 4)
	if err != nil {
		t.Fatalf("acquireOutbound setup failed: %v", err)
	}
	defer release()

	stream := newStreamHandle(context.Background(), runtime, "test", StreamOpenRequest{
		StreamID: "stream-budget-close",
		DataID:   1,
	}, 0, nil, nil, 0, nil, nil, func(context.Context, *streamHandle, []byte) error {
		return nil
	}, cfg)

	errCh := make(chan error, 1)
	go func() {
		_, err := stream.Write([]byte("abcd"))
		errCh <- err
	}()

	// Give the writer goroutine time to start waiting on the budget.
	time.Sleep(20 * time.Millisecond)
	stream.markPeerClosed()

	select {
	case err := <-errCh:
		if !errors.Is(err, io.ErrClosedPipe) {
			t.Fatalf("stream Write error = %v, want %v", err, io.ErrClosedPipe)
		}
	case <-time.After(time.Second):
		t.Fatal("stream Write did not return after peer close")
	}
}

// TestStreamReadWaitingLocalClosePrefersClosedPipeOverContextCanceled checks
// that a Read blocked on an empty stream returns io.ErrClosedPipe when the
// stream is closed locally.
func TestStreamReadWaitingLocalClosePrefersClosedPipeOverContextCanceled(t *testing.T) {
	stream := newStreamHandle(context.Background(), nil, "test", StreamOpenRequest{
		StreamID: "stream-read-local-close",
		DataID:   1,
	}, 0, nil, nil, 0, nil, nil, nil, streamConfig{})

	errCh := make(chan error, 1)
	go func() {
		buf := make([]byte, 4)
		_, err := stream.Read(buf)
		errCh <- err
	}()

	// Give the reader goroutine time to block in Read.
	time.Sleep(20 * time.Millisecond)
	if err := stream.Close(); err != nil {
		t.Fatalf("stream Close failed: %v", err)
	}

	select {
	case err := <-errCh:
		if !errors.Is(err, io.ErrClosedPipe) {
			t.Fatalf("stream Read error = %v, want %v", err, io.ErrClosedPipe)
		}
	case <-time.After(time.Second):
		t.Fatal("stream Read did not return after local close")
	}
}

// TestStreamOpenRoundTripServerToClientTCP mirrors TestStreamOpenRoundTripTCP
// but in the other direction. The server opens a control-channel stream
// toward a connected client, and the client's StreamAcceptInfo is verified
// (with LogicalConn expected to be nil on the client side) before data is
// exchanged and the stream is closed.
func TestStreamOpenRoundTripServerToClientTCP(t *testing.T) {
	server := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	defer func() { _ = server.Stop() }()

	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	acceptCh := make(chan StreamAcceptInfo, 1)
	client.SetStreamHandler(func(info StreamAcceptInfo) error {
		acceptCh <- info
		return nil
	})
	if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	defer func() { _ = client.Stop() }()

	logical := waitForTransferControlLogicalConn(t, server, 2*time.Second)
	stream, err := server.OpenStreamLogical(context.Background(), logical, StreamOpenOptions{
		Channel: StreamControlChannel,
		Metadata: StreamMetadata{
			"purpose": "server-open",
		},
	})
	if err != nil {
		t.Fatalf("server OpenStreamLogical failed: %v", err)
	}

	var accepted StreamAcceptInfo
	select {
	case accepted = <-acceptCh:
	case <-time.After(2 * time.Second):
		t.Fatal("timed out waiting for client accepted stream")
	}
	if accepted.ID != stream.ID() {
		t.Fatalf("client accepted stream id mismatch: got %q want %q", accepted.ID, stream.ID())
	}
	if accepted.Channel != StreamControlChannel {
		t.Fatalf("client accepted stream channel mismatch: got %q want %q", accepted.Channel, StreamControlChannel)
	}
	if accepted.Metadata["purpose"] != "server-open" {
		t.Fatalf("client accepted metadata mismatch: %+v", accepted.Metadata)
	}
	if accepted.LogicalConn != nil {
		t.Fatalf("client accepted logical connection should be nil: %+v", accepted.LogicalConn)
	}
	serverHandle, ok := stream.(*streamHandle)
	if !ok {
		t.Fatalf("stream type = %T, want *streamHandle", stream)
	}
	if accepted.DataID == 0 {
		t.Fatal("client accepted stream data id should not be zero")
	}
	if got, want := serverHandle.dataIDSnapshot(), accepted.DataID; got != want {
		t.Fatalf("server stream data id = %d, want %d", got, want)
	}

	if _, err := stream.Write([]byte("server-opened")); err != nil {
		t.Fatalf("server stream Write failed: %v", err)
	}
	readStreamExactly(t, accepted.Stream, "server-opened", 2*time.Second)
	if _, err := accepted.Stream.Write([]byte("client-accepted")); err != nil {
		t.Fatalf("client accepted stream Write failed: %v", err)
	}
	readStreamExactly(t, stream, "client-accepted", 2*time.Second)

	if err := stream.Close(); err != nil {
		t.Fatalf("server stream Close failed: %v", err)
	}
	waitForStreamReadEOF(t, accepted.Stream, 2*time.Second)
	if err := accepted.Stream.Close(); err != nil {
		t.Fatalf("client accepted stream Close failed: %v", err)
	}
	waitForStreamContextDone(t, stream.Context(), 2*time.Second)
}

// TestStreamResetRoundTripTCP checks that a Reset on the server side reaches
// the client. The client's Read error must contain the reset cause text, and
// the client stream's context must end.
func TestStreamResetRoundTripTCP(t *testing.T) {
	server := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	acceptCh := make(chan StreamAcceptInfo, 1)
	server.SetStreamHandler(func(info StreamAcceptInfo) error {
		acceptCh <- info
		return nil
	})
	if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	defer func() { _ = server.Stop() }()

	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	defer func() { _ = client.Stop() }()

	stream, err := client.OpenStream(context.Background(), StreamOpenOptions{Channel: StreamDataChannel})
	if err != nil {
		t.Fatalf("client OpenStream failed: %v", err)
	}
	accepted := waitAcceptedStream(t, acceptCh, 2*time.Second)

	resetCause := errors.New("stream-reset-by-server")
	if err := accepted.Stream.Reset(resetCause); err != nil {
		t.Fatalf("server accepted stream Reset failed: %v", err)
	}

	readErr := readStreamError(t, stream, 2*time.Second)
	// The cause crosses the wire as text, so only its message can be compared.
	if !strings.Contains(readErr.Error(), resetCause.Error()) {
		t.Fatalf("stream Read reset error mismatch: got %v want %q", readErr, resetCause.Error())
	}
	waitForStreamContextDone(t, stream.Context(), 2*time.Second)
}

// TestStreamSetReadDeadlineUnblocksPendingRead checks that a SetReadDeadline
// issued while a Read is already blocked makes that Read return
// os.ErrDeadlineExceeded. The deadline is shorter than the stream's 1s
// ReadTimeout.
func TestStreamSetReadDeadlineUnblocksPendingRead(t *testing.T) {
	runtime := newStreamRuntime("read-deadline")
	stream := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{
		StreamID:    "read-deadline-stream",
		Channel:     StreamDataChannel,
		ReadTimeout: time.Second,
	}, 0, nil, nil, 0, nil, nil, nil, runtime.configSnapshot())
	if err := runtime.register(clientFileScope(), stream); err != nil {
		t.Fatalf("register stream failed: %v", err)
	}

	errCh := make(chan error, 1)
	go func() {
		buf := make([]byte, 1)
		_, err := stream.Read(buf)
		errCh <- err
	}()

	// Let the reader block before moving the deadline in.
	time.Sleep(20 * time.Millisecond)
	if err := stream.SetReadDeadline(time.Now().Add(40 * time.Millisecond)); err != nil {
		t.Fatalf("SetReadDeadline failed: %v", err)
	}

	select {
	case err := <-errCh:
		if !errors.Is(err, os.ErrDeadlineExceeded) {
			t.Fatalf("stream Read error = %v, want %v", err, os.ErrDeadlineExceeded)
		}
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for read deadline")
	}
}

// TestStreamSetWriteDeadlineUnblocksBlockedWrite checks that a
// SetWriteDeadline issued while a Write is blocked on exhausted outbound
// budget makes that Write return os.ErrDeadlineExceeded.
func TestStreamSetWriteDeadlineUnblocksBlockedWrite(t *testing.T) {
	runtime := newStreamRuntime("write-deadline")
	runtime.applyConfig(streamConfig{
		ChunkSize:                 4,
		InboundQueueLimit:         defaultStreamInboundQueueLimit,
		InboundBufferedBytesLimit: defaultStreamInboundBufferedBytesLimit,
		OutboundWindowBytes:       4,
		OutboundMaxInFlightChunks: 1,
	})
	holdCtx, holdCancel := context.WithCancel(context.Background())
	defer holdCancel()
	// Hold the whole outbound window so the stream's Write has to wait.
	release, err := runtime.acquireOutbound(holdCtx, 4)
	if err != nil {
		t.Fatalf("acquireOutbound failed: %v", err)
	}
	defer release()

	stream := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{
		StreamID:     "write-deadline-stream",
		Channel:      StreamDataChannel,
		WriteTimeout: time.Second,
		ReadTimeout:  time.Second,
	}, 0, nil, nil, 0, nil, nil, func(context.Context, *streamHandle, []byte) error {
		return nil
	}, runtime.configSnapshot())
	if err := runtime.register(clientFileScope(), stream); err != nil {
		t.Fatalf("register stream failed: %v", err)
	}

	errCh := make(chan error, 1)
	go func() {
		_, err := stream.Write([]byte("abcd"))
		errCh <- err
	}()

	// Let the writer block before moving the deadline in.
	time.Sleep(20 * time.Millisecond)
	if err := stream.SetWriteDeadline(time.Now().Add(40 * time.Millisecond)); err != nil {
		t.Fatalf("SetWriteDeadline failed: %v", err)
	}

	select {
	case err := <-errCh:
		if !errors.Is(err, os.ErrDeadlineExceeded) {
			t.Fatalf("stream Write error = %v, want %v", err, os.ErrDeadlineExceeded)
		}
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for write deadline")
	}
}

// TestStreamImplementsNetConnTCP uses both ends of a TCP-backed stream purely
// through the net.Conn interface. It checks addresses, SetDeadline, and a
// Write in each direction.
func TestStreamImplementsNetConnTCP(t *testing.T) {
	server := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	acceptCh := make(chan StreamAcceptInfo, 1)
	server.SetStreamHandler(func(info StreamAcceptInfo) error {
		acceptCh <- info
		return nil
	})
	if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	defer func() { _ = server.Stop() }()

	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	defer func() { _ = client.Stop() }()

	stream, err := client.OpenStream(context.Background(), StreamOpenOptions{Channel: StreamDataChannel})
	if err != nil {
		t.Fatalf("client OpenStream failed: %v", err)
	}
	accepted := waitAcceptedStream(t, acceptCh, 2*time.Second)

	var clientConn net.Conn = stream
	var serverConn net.Conn = accepted.Stream
	if clientConn.LocalAddr() == nil || clientConn.RemoteAddr() == nil {
		t.Fatalf("client stream net.Conn addrs missing: local=%v remote=%v", clientConn.LocalAddr(), clientConn.RemoteAddr())
	}
	if serverConn.LocalAddr() == nil || serverConn.RemoteAddr() == nil {
		t.Fatalf("server stream net.Conn addrs missing: local=%v remote=%v", serverConn.LocalAddr(), serverConn.RemoteAddr())
	}
	if err := clientConn.SetDeadline(time.Now().Add(time.Second)); err != nil {
		t.Fatalf("client stream SetDeadline failed: %v", err)
	}
	if err := serverConn.SetDeadline(time.Now().Add(time.Second)); err != nil {
		t.Fatalf("server stream SetDeadline failed: %v", err)
	}

	if _, err := clientConn.Write([]byte("from-net-conn-client")); err != nil {
		t.Fatalf("client net.Conn Write failed: %v", err)
	}
	readStreamExactly(t, accepted.Stream, "from-net-conn-client", 2*time.Second)
	if _, err := serverConn.Write([]byte("from-net-conn-server")); err != nil {
		t.Fatalf("server net.Conn Write failed: %v", err)
	}
	readStreamExactly(t, stream, "from-net-conn-server", 2*time.Second)
}

// TestClientDispatchStreamEnvelopeRejectsStaleSessionEpoch registers a client
// stream bound to an earlier session epoch and then dispatches data for it.
// The stream must fail with errTransportDetached, its context must end, and
// it must be removed from the runtime.
func TestClientDispatchStreamEnvelopeRejectsStaleSessionEpoch(t *testing.T) {
	client := NewClient().(*ClientCommon)
	runtime := client.getStreamRuntime()
	staleEpoch := client.beginClientSessionEpoch()
	// Advance the epoch so staleEpoch no longer matches the current session.
	_ = client.beginClientSessionEpoch()

	stream := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{
		StreamID:    "client-stale",
		Channel:     StreamDataChannel,
		ReadTimeout: 20 * time.Millisecond,
	}, staleEpoch, nil, nil, 0, nil, nil, nil, defaultStreamConfig())
	if err := runtime.register(clientFileScope(), stream); err != nil {
		t.Fatalf("register stale client stream failed: %v", err)
	}

	client.dispatchStreamEnvelope(newStreamDataEnvelope("client-stale", []byte("payload")))

	readErr := readStreamError(t, stream, time.Second)
	if !errors.Is(readErr, errTransportDetached) {
		t.Fatalf("stale client stream read error mismatch: got %v want %v", readErr, errTransportDetached)
	}
	waitForStreamContextDone(t, stream.Context(), time.Second)
	if _, ok := runtime.lookup(clientFileScope(), "client-stale"); ok {
		t.Fatal("stale client stream should be removed from runtime")
	}
}

// TestServerDispatchStreamEnvelopeRejectsTransportGenerationMismatch checks
// two things for a server stream bound to transport generation 1. Data
// arriving on generation 2 is not delivered: the Read only hits the 20ms
// ReadTimeout. Data later arriving on generation 1 is delivered.
func TestServerDispatchStreamEnvelopeRejectsTransportGenerationMismatch(t *testing.T) {
	server := NewServer().(*ServerCommon)
	runtime := server.getStreamRuntime()
	clientConn := &ClientConn{
		ClientID: "server-stale-peer",
		server:   server,
	}
	logical := logicalConnFromClient(clientConn)
	scope := serverFileScope(logical)
	stream := newStreamHandle(context.Background(), runtime, scope, StreamOpenRequest{
		StreamID:    "server-stale",
		Channel:     StreamDataChannel,
		ReadTimeout: 20 * time.Millisecond,
	}, 0, logical, &TransportConn{
		logical:    logical,
		generation: 1,
		remoteAddr: streamTestAddr("current"),
		attached:   true,
	}, 1, nil, nil, nil, defaultStreamConfig())
	if err := runtime.register(scope, stream); err != nil {
		t.Fatalf("register server stream failed: %v", err)
	}

	server.dispatchStreamEnvelope(logical, &TransportConn{
		logical:    logical,
		generation: 2,
		remoteAddr: streamTestAddr("stale"),
		attached:   true,
	}, nil, newStreamDataEnvelope("server-stale", []byte("stale-payload")))

	readErr := readStreamError(t, stream, time.Second)
	if !errors.Is(readErr, os.ErrDeadlineExceeded) {
		t.Fatalf("server stale generation read error mismatch: got %v want %v", readErr, os.ErrDeadlineExceeded)
	}

	server.dispatchStreamEnvelope(logical, &TransportConn{
		logical:    logical,
		generation: 1,
		remoteAddr: streamTestAddr("current"),
		attached:   true,
	}, nil, newStreamDataEnvelope("server-stale", []byte("good-payload")))
	readStreamExactly(t, stream, "good-payload", time.Second)
}

// TestServerDispatchStreamEnvelopeRejectsTransportGenerationMismatchWritesResetViaInboundConn
// dispatches data for a registered stream through a transport snapshot one
// generation ahead of the stream's binding. It checks that the server writes
// a stream-reset signal back on the inbound conn: an async
// StreamResetSignalKey transfer whose error starts with errTransportDetached.
func TestServerDispatchStreamEnvelopeRejectsTransportGenerationMismatchWritesResetViaInboundConn(t *testing.T) {
	server := NewServer().(*ServerCommon)
	UseLegacySecurityServer(server)
	runtimeCtx, runtimeCancel := context.WithCancel(context.Background())
	defer runtimeCancel()
	server.setServerSessionRuntime(&serverSessionRuntime{
		stopCtx: runtimeCtx,
		stopFn:  runtimeCancel,
		queue:   stario.NewQueueCtx(runtimeCtx, 4, math.MaxUint32),
	})

	left, right := net.Pipe()
	defer left.Close()
	defer right.Close()

	logical := server.bootstrapAcceptedLogical("server-stream-reset", nil, left)
	if logical == nil {
		t.Fatal("bootstrapAcceptedLogical should return logical")
	}
	transport := logical.CurrentTransportConn()
	if transport == nil {
		t.Fatal("current transport should exist")
	}

	runtime := server.getStreamRuntime()
	stream := newStreamHandle(context.Background(), runtime, serverFileScope(logical), StreamOpenRequest{
		StreamID: "server-stale-reset",
		Channel:  StreamDataChannel,
	}, 0, logical, transport, transport.TransportGeneration(), nil, nil, nil, defaultStreamConfig())
	if err := runtime.register(serverFileScope(logical), stream); err != nil {
		t.Fatalf("register server stream failed: %v", err)
	}

	staleTransport := logical.transportConnSnapshotForInbound(left, nil, transport.TransportGeneration()+1, true)
	if staleTransport == nil {
		t.Fatal("stale transport snapshot should exist")
	}

	// Dispatch on a goroutine: the reset it writes to `left` only completes
	// once the test reads it from `right`, since net.Pipe is unbuffered.
	done := make(chan struct{})
	go func() {
		server.dispatchStreamEnvelope(logical, staleTransport, left, newStreamDataEnvelope("server-stale-reset", []byte("stale-payload")))
		close(done)
	}()

	env := readServerEnvelopeFromConn(t, server, logical, right, time.Second)
	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for stream dispatch to finish")
	}

	if env.Kind != EnvelopeSignal {
		t.Fatalf("reset envelope kind = %v, want %v", env.Kind, EnvelopeSignal)
	}
	transfer, err := unwrapTransferMsgEnvelope(env, server.sequenceDe)
	if err != nil {
		t.Fatalf("unwrapTransferMsgEnvelope failed: %v", err)
	}
	if transfer.Key != StreamResetSignalKey {
		t.Fatalf("reset transfer key = %q, want %q", transfer.Key, StreamResetSignalKey)
	}
	if transfer.Type != MSG_ASYNC {
		t.Fatalf("reset transfer type = %v, want %v", transfer.Type, MSG_ASYNC)
	}
	var req StreamResetRequest
	if err := transfer.Value.Orm(&req); err != nil {
		t.Fatalf("decode reset request failed: %v", err)
	}
	if req.StreamID != "server-stale-reset" {
		t.Fatalf("reset stream id = %q, want %q", req.StreamID, "server-stale-reset")
	}
	if !strings.HasPrefix(req.Error, errTransportDetached.Error()) {
		t.Fatalf("reset error = %q, want prefix %q", req.Error, errTransportDetached.Error())
	}
}

// TestStreamBackpressureOverflowResetsStreamAndRemovesRuntimeEntry fills the
// inbound limits (1 chunk / 4 bytes), then pushes one more chunk. It checks
// that the overflow fails with errStreamBackpressureExceeded, that subsequent
// reads report the same error, and that the stream leaves the runtime.
func TestStreamBackpressureOverflowResetsStreamAndRemovesRuntimeEntry(t *testing.T) {
	runtime := newStreamRuntime("overflow")
	runtime.applyConfig(streamConfig{
		ChunkSize:                 4,
		InboundQueueLimit:         1,
		InboundBufferedBytesLimit: 4,
	})
	stream := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{
		StreamID:    "overflow-stream",
		Channel:     StreamDataChannel,
		ReadTimeout: 20 * time.Millisecond,
	}, 0, nil, nil, 0, nil, nil, nil, runtime.configSnapshot())
	if err := runtime.register(clientFileScope(), stream); err != nil {
		t.Fatalf("register stream failed: %v", err)
	}

	if err := stream.pushChunk([]byte("abcd")); err != nil {
		t.Fatalf("first pushChunk failed: %v", err)
	}
	if err := stream.pushChunk([]byte("ef")); !errors.Is(err, errStreamBackpressureExceeded) {
		t.Fatalf("overflow pushChunk error = %v, want %v", err, errStreamBackpressureExceeded)
	}

	readErr := readStreamError(t, stream, time.Second)
	if !errors.Is(readErr, errStreamBackpressureExceeded) {
		t.Fatalf("stream read error = %v, want %v", readErr, errStreamBackpressureExceeded)
	}
	if _, ok := runtime.lookup(clientFileScope(), "overflow-stream"); ok {
		t.Fatal("overflowed stream should be removed from runtime")
	}
}

// TestServerDetachLogicalSessionTransportResetsScopedStreams checks that
// detaching a logical connection's transport makes streams in its scope fail
// with errTransportDetached and removes them from the runtime.
func TestServerDetachLogicalSessionTransportResetsScopedStreams(t *testing.T) {
	server := NewServer().(*ServerCommon)
	runtime := server.getStreamRuntime()
	client := &ClientConn{
		ClientID: "detached-peer",
		server:   server,
	}
	logical := logicalConnFromClient(client)
	scope := serverFileScope(logical)
	stream := newStreamHandle(context.Background(), runtime, scope, StreamOpenRequest{
		StreamID:    "detach-stream",
		Channel:     StreamDataChannel,
		ReadTimeout: 20 * time.Millisecond,
	}, 0, logical, &TransportConn{
		logical:    logical,
		generation: 1,
		remoteAddr: streamTestAddr("detach"),
		attached:   true,
	}, 1, nil, nil, nil, defaultStreamConfig())
	if err := runtime.register(scope, stream); err != nil {
		t.Fatalf("register stream failed: %v", err)
	}

	left, right := net.Pipe()
	defer left.Close()
	defer right.Close()
	logical.startSession(left, nil, nil)

	server.detachLogicalSessionTransport(logical, "read error", errors.New("boom"))

	readErr := readStreamError(t, stream, time.Second)
	if !errors.Is(readErr, errTransportDetached) {
		t.Fatalf("detached stream read error = %v, want %v", readErr, errTransportDetached)
	}
	if _, ok := runtime.lookup(scope, "detach-stream"); ok {
		t.Fatal("detached stream should be removed from runtime")
	}
}

// TestGetStreamSnapshotsIncludesBufferedState checks that
// GetClientStreamSnapshots reports a registered client stream's identity
// fields, its buffered chunk and byte counts, its metadata, its timeouts, and
// its "client-session" binding owner.
func TestGetStreamSnapshotsIncludesBufferedState(t *testing.T) {
	client := NewClient().(*ClientCommon)
	runtime := client.getStreamRuntime()
	stream := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{
		StreamID:     "snapshot-stream",
		Channel:      StreamControlChannel,
		ReadTimeout:  time.Second,
		WriteTimeout: 2 * time.Second,
		Metadata: StreamMetadata{
			"name": "snapshot-demo",
		},
	}, 7, nil, nil, 0, nil, nil, nil, defaultStreamConfig())
	stream.setClientSnapshotOwner(client)
	if err := runtime.register(clientFileScope(), stream); err != nil {
		t.Fatalf("register snapshot stream failed: %v", err)
	}
	if err := stream.pushChunk([]byte("hello")); err != nil {
		t.Fatalf("pushChunk failed: %v", err)
	}

	snapshots, err := GetClientStreamSnapshots(client)
	if err != nil {
		t.Fatalf("GetClientStreamSnapshots failed: %v", err)
	}
	if got, want := len(snapshots), 1; got != want {
		t.Fatalf("stream snapshot count = %d, want %d", got, want)
	}
	snapshot := snapshots[0]
	if got, want := snapshot.ID, "snapshot-stream"; got != want {
		t.Fatalf("snapshot ID = %q, want %q", got, want)
	}
	if got, want := snapshot.Scope, clientFileScope(); got != want {
		t.Fatalf("snapshot Scope = %q, want %q", got, want)
	}
	if got, want := snapshot.Channel, StreamControlChannel; got != want {
		t.Fatalf("snapshot Channel = %q, want %q", got, want)
	}
	if got, want := snapshot.SessionEpoch, uint64(7); got != want {
		t.Fatalf("snapshot SessionEpoch = %d, want %d", got, want)
	}
	if got, want := snapshot.BufferedChunks, 1; got != want {
		t.Fatalf("snapshot BufferedChunks = %d, want %d", got, want)
	}
	if got, want := snapshot.BufferedBytes, 5; got != want {
		t.Fatalf("snapshot BufferedBytes = %d, want %d", got, want)
	}
	if snapshot.LocalReadClosed {
		t.Fatal("snapshot LocalReadClosed should be false")
	}
	if snapshot.PeerReadClosed {
		t.Fatal("snapshot PeerReadClosed should be false")
	}
	if got := snapshot.Metadata["name"]; got != "snapshot-demo" {
		t.Fatalf("snapshot metadata mismatch: %+v", snapshot.Metadata)
	}
	if got, want := snapshot.ReadTimeout, time.Second; got != want {
		t.Fatalf("snapshot ReadTimeout = %v, want %v", got, want)
	}
	if got, want := snapshot.WriteTimeout, 2*time.Second; got != want {
		t.Fatalf("snapshot WriteTimeout = %v, want %v", got, want)
	}
	if got, want := snapshot.BindingOwner, "client-session"; got != want {
		t.Fatalf("snapshot BindingOwner = %q, want %q", got, want)
	}
}

// TestGetStreamSnapshotsIncludesIOObservability performs one Read and one
// Write and sets both deadlines. It then checks that runtime.snapshots()
// reports the addresses, byte and call counters, last-activity timestamps,
// and deadlines.
func TestGetStreamSnapshotsIncludesIOObservability(t *testing.T) {
	runtime := newStreamRuntime("snapshot-observe")
	stream := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{
		StreamID: "snapshot-observe-stream",
		Channel:  StreamDataChannel,
	}, 0, nil, nil, 0, nil, nil, func(context.Context, *streamHandle, []byte) error {
		return nil
	}, runtime.configSnapshot())
	stream.setAddrSnapshot(streamTestAddr("local-addr"), streamTestAddr("remote-addr"))
	if err := runtime.register(clientFileScope(), stream); err != nil {
		t.Fatalf("register stream failed: %v", err)
	}
	if err := stream.pushChunk([]byte("hello")); err != nil {
		t.Fatalf("pushChunk failed: %v", err)
	}

	buf := make([]byte, 2)
	if _, err := stream.Read(buf); err != nil {
		t.Fatalf("stream Read failed: %v", err)
	}
	if _, err := stream.Write([]byte("world")); err != nil {
		t.Fatalf("stream Write failed: %v", err)
	}
	// Round(0) strips the monotonic reading so the values compare cleanly with
	// what the snapshot reports.
	readDeadline := time.Now().Add(time.Minute).Round(0)
	writeDeadline := time.Now().Add(2 * time.Minute).Round(0)
	if err := stream.SetReadDeadline(readDeadline); err != nil {
		t.Fatalf("SetReadDeadline failed: %v", err)
	}
	if err := stream.SetWriteDeadline(writeDeadline); err != nil {
		t.Fatalf("SetWriteDeadline failed: %v", err)
	}

	snapshots := runtime.snapshots()
	if got, want := len(snapshots), 1; got != want {
		t.Fatalf("snapshot count = %d, want %d", got, want)
	}
	snapshot := snapshots[0]
	if got, want := snapshot.LocalAddress, "local-addr"; got != want {
		t.Fatalf("snapshot LocalAddress = %q, want %q", got, want)
	}
	if got, want := snapshot.RemoteAddress, "remote-addr"; got != want {
		t.Fatalf("snapshot RemoteAddress = %q, want %q", got, want)
	}
	if got, want := snapshot.BytesRead, int64(2); got != want {
		t.Fatalf("snapshot BytesRead = %d, want %d", got, want)
	}
	if got, want := snapshot.BytesWritten, int64(5); got != want {
		t.Fatalf("snapshot BytesWritten = %d, want %d", got, want)
	}
	if got, want := snapshot.ReadCalls, int64(1); got != want {
		t.Fatalf("snapshot ReadCalls = %d, want %d", got, want)
	}
	if got, want := snapshot.WriteCalls, int64(1); got != want {
		t.Fatalf("snapshot WriteCalls = %d, want %d", got, want)
	}
	if snapshot.OpenedAt.IsZero() {
		t.Fatal("snapshot OpenedAt should not be zero")
	}
	if snapshot.LastReadAt.IsZero() {
		t.Fatal("snapshot LastReadAt should not be zero")
	}
	if snapshot.LastWriteAt.IsZero() {
		t.Fatal("snapshot LastWriteAt should not be zero")
	}
	if got, want := snapshot.ReadDeadline, readDeadline; !got.Equal(want) {
		t.Fatalf("snapshot ReadDeadline = %v, want %v", got, want)
	}
	if got, want := snapshot.WriteDeadline, writeDeadline; !got.Equal(want) {
		t.Fatalf("snapshot WriteDeadline = %v, want %v", got, want)
	}
}

// TestStreamSnapshotIncludesDetachedBindingDiagnostics checks that a stream
// bound to a server transport reports after detach that it is no longer
// current or attached. It also checks that the snapshot carries the detach
// reason, kind, and error.
func TestStreamSnapshotIncludesDetachedBindingDiagnostics(t *testing.T) {
	server := NewServer().(*ServerCommon)
	left, right := net.Pipe()
	// Close both pipe ends like the other net.Pipe-based tests. Closing a
	// net.Pipe end that was already closed by the detach is harmless.
	defer left.Close()
	defer right.Close()

	logical := server.bootstrapAcceptedLogical("stream-snapshot-detach", nil, left)
	if logical == nil {
		t.Fatal("bootstrapAcceptedLogical should return logical")
	}
	transport := logical.CurrentTransportConn()
	if transport == nil {
		t.Fatal("CurrentTransportConn should return active transport")
	}
	stream := newStreamHandle(context.Background(), newStreamRuntime("snapshot-detach"), serverFileScope(logical), StreamOpenRequest{
		StreamID: "stream-snapshot-detach",
	}, 0, logical, transport, transport.TransportGeneration(), nil, nil, nil, defaultStreamConfig())

	server.detachLogicalSessionTransport(logical, "read error", errors.New("boom"))

	snapshot := stream.snapshot()
	if got, want := snapshot.BindingOwner, "server-transport"; got != want {
		t.Fatalf("snapshot BindingOwner = %q, want %q", got, want)
	}
	if snapshot.BindingCurrent {
		t.Fatalf("snapshot BindingCurrent should be false after detach: %+v", snapshot)
	}
	if snapshot.TransportAttached {
		t.Fatalf("snapshot TransportAttached should be false after detach: %+v", snapshot)
	}
	if snapshot.TransportCurrent {
		t.Fatalf("snapshot TransportCurrent should be false after detach: %+v", snapshot)
	}
	if got, want := snapshot.TransportDetachReason, "read error"; got != want {
		t.Fatalf("snapshot TransportDetachReason = %q, want %q", got, want)
	}
	if got, want := snapshot.TransportDetachKind, clientConnTransportDetachKindReadError; got != want {
		t.Fatalf("snapshot TransportDetachKind = %q, want %q", got, want)
	}
	if got, want := snapshot.TransportDetachError, "boom"; got != want {
		t.Fatalf("snapshot TransportDetachError = %q, want %q", got, want)
	}
}

// waitForStreamReadEOF reads from stream until a Read reports io.EOF.
// Bytes received before EOF are discarded, and errStreamDataPathNotReady is
// retried after a 10ms pause. It fails the test on any other Read error, or
// when timeout elapses first.
//
// Reads run on a helper goroutine bounded by a context deadline, so a Read
// that blocks indefinitely cannot hang the test past timeout. t.Fatal is only
// ever called from the test goroutine.
func waitForStreamReadEOF(t *testing.T, stream Stream, timeout time.Duration) {
	t.Helper()
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	// Receives nil on EOF or the first unexpected Read error. Nothing is
	// sent if the deadline expires first.
	resultCh := make(chan error, 1)
	go func() {
		buf := make([]byte, 1)
		for ctx.Err() == nil {
			_, err := stream.Read(buf)
			if errors.Is(err, io.EOF) {
				resultCh <- nil
				return
			}
			if err != nil && !errors.Is(err, errStreamDataPathNotReady) {
				resultCh <- err
				return
			}
			select {
			case <-ctx.Done():
				return
			case <-time.After(10 * time.Millisecond):
			}
		}
	}()

	select {
	case err := <-resultCh:
		if err != nil {
			t.Fatalf("stream Read returned unexpected error: %v", err)
		}
	case <-ctx.Done():
		// A result that raced with the deadline still counts.
		select {
		case err := <-resultCh:
			if err != nil {
				t.Fatalf("stream Read returned unexpected error: %v", err)
			}
			return
		default:
		}
		t.Fatal("timed out waiting for stream EOF")
	}
}

// waitForStreamContextDone fails the test unless ctx is done within timeout.
func waitForStreamContextDone(t *testing.T, ctx context.Context, timeout time.Duration) {
	t.Helper()
	select {
	case <-ctx.Done():
	case <-time.After(timeout):
		t.Fatal("timed out waiting for stream context done")
	}
}

// waitAcceptedStream returns the next StreamAcceptInfo from ch, failing the
// test if none arrives within timeout.
func waitAcceptedStream(t *testing.T, ch <-chan StreamAcceptInfo, timeout time.Duration) StreamAcceptInfo {
	t.Helper()
	select {
	case info := <-ch:
		return info
	case <-time.After(timeout):
		t.Fatal("timed out waiting for accepted stream")
		return StreamAcceptInfo{}
	}
}

// readStreamExactly reads exactly len(want) bytes from stream and fails the
// test if the read errors, the bytes differ from want, or timeout elapses
// first. The read runs on a helper goroutine so a blocked Read cannot hang the
// test.
func readStreamExactly(t *testing.T, stream Stream, want string, timeout time.Duration) {
	t.Helper()
	errCh := make(chan error, 1)
	go func() {
		buf := make([]byte, len(want))
		_, err := io.ReadFull(stream, buf)
		if err != nil {
			errCh <- err
			return
		}
		if got := string(buf); got != want {
			errCh <- errors.New("stream payload mismatch: got " + got + " want " + want)
			return
		}
		errCh <- nil
	}()

	select {
	case err := <-errCh:
		if err != nil {
			t.Fatal(err)
		}
	case <-time.After(timeout):
		t.Fatal("timed out waiting for stream payload")
	}
}

// readStreamError performs a single one-byte Read on stream and returns its
// error. It fails the test if the Read succeeds or does not return within
// timeout.
func readStreamError(t *testing.T, stream Stream, timeout time.Duration) error {
	t.Helper()
	errCh := make(chan error, 1)
	go func() {
		buf := make([]byte, 1)
		_, err := stream.Read(buf)
		errCh <- err
	}()

	select {
	case err := <-errCh:
		if err == nil {
			t.Fatal("expected stream read error, got nil")
		}
		return err
	case <-time.After(timeout):
		t.Fatal("timed out waiting for stream read error")
		return nil
	}
}

// readServerEnvelopeFromConn reads raw bytes from conn and frames them with a
// stario queue until one message is available. It then decodes that message
// via server.decodeEnvelopeLogical. The test fails on parse, decode, or
// non-timeout read errors, or if no message arrives within timeout.
func readServerEnvelopeFromConn(t *testing.T, server *ServerCommon, logical *LogicalConn, conn net.Conn, timeout time.Duration) Envelope {
	t.Helper()
	queue := stario.NewQueue()
	deadline := time.Now().Add(timeout)
	// The deadline is absolute, so setting it once bounds every Read below.
	if err := conn.SetReadDeadline(deadline); err != nil {
		t.Fatalf("SetReadDeadline failed: %v", err)
	}
	buf := make([]byte, 4096)
	for time.Now().Before(deadline) {
		n, err := conn.Read(buf)
		if n > 0 {
			if parseErr := queue.ParseMessage(buf[:n], "stream-test"); parseErr != nil {
				t.Fatalf("ParseMessage failed: %v", parseErr)
			}
			select {
			case msg := <-queue.RestoreChan():
				env, decErr := server.decodeEnvelopeLogical(logical, msg.Msg)
				if decErr != nil {
					t.Fatalf("decodeEnvelopeLogical failed: %v", decErr)
				}
				return env
			default:
				// Frame not complete yet; keep reading.
			}
		}
		if err == nil {
			continue
		}
		var netErr net.Error
		if errors.Is(err, os.ErrDeadlineExceeded) || (errors.As(err, &netErr) && netErr.Timeout()) {
			break
		}
		t.Fatalf("conn Read failed: %v", err)
	}
	t.Fatal("timed out waiting for server envelope")
	return Envelope{}
}

// streamTestAddr is a minimal net.Addr whose String value is the address text
// itself. It lets tests supply recognizable local and remote addresses.
type streamTestAddr string

func (a streamTestAddr) Network() string { return "stream-test" }

func (a streamTestAddr) String() string { return string(a) }

var _ net.Addr = streamTestAddr("")