notify/stream_test.go
starainrt 09d972c7b7
feat(notify): 重构通信内核并补齐 stream/bulk/record/transfer 能力
- 引入 LogicalConn/TransportConn 分层,ClientConn 保留兼容适配层
  - 新增 Stream、Bulk、RecordStream 三条数据面能力及对应控制路径
  - 完成 transfer/file 传输内核与状态快照、诊断能力
  - 补齐 reconnect、inbound dispatcher、modern psk 等基础模块
  - 增加大规模回归、并发与基准测试覆盖
  - 更新依赖库
2026-04-15 15:24:36 +08:00

1180 lines
38 KiB
Go

package notify
import (
"b612.me/stario"
"context"
"errors"
"io"
"math"
"net"
"os"
"strings"
"testing"
"time"
)
// TestStreamOpenRoundTripTCP opens a data-channel stream from the client over a
// modern-PSK TCP connection and verifies that:
//   - the server's stream handler receives the same stream ID, channel,
//     metadata and data ID, with non-nil logical and transport bindings;
//   - payloads flow client->server and server->client;
//   - closing the client stream delivers EOF to the server, and once the
//     server closes its side too, the client stream's context is done.
func TestStreamOpenRoundTripTCP(t *testing.T) {
	server := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	acceptCh := make(chan StreamAcceptInfo, 1)
	server.SetStreamHandler(func(info StreamAcceptInfo) error {
		acceptCh <- info
		return nil
	})
	if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	defer func() {
		_ = server.Stop()
	}()
	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	defer func() {
		_ = client.Stop()
	}()
	stream, err := client.OpenStream(context.Background(), StreamOpenOptions{
		Channel: StreamDataChannel,
		Metadata: StreamMetadata{
			"name": "demo.bin",
		},
	})
	if err != nil {
		t.Fatalf("client OpenStream failed: %v", err)
	}
	// Use the shared helper (as the other TCP stream tests do) instead of a
	// hand-rolled select, so accept timeouts are handled consistently.
	accepted := waitAcceptedStream(t, acceptCh, 2*time.Second)
	if accepted.ID != stream.ID() {
		t.Fatalf("accepted stream id mismatch: got %q want %q", accepted.ID, stream.ID())
	}
	if accepted.Channel != StreamDataChannel {
		t.Fatalf("accepted stream channel mismatch: got %q want %q", accepted.Channel, StreamDataChannel)
	}
	if accepted.Metadata["name"] != "demo.bin" {
		t.Fatalf("accepted metadata mismatch: %+v", accepted.Metadata)
	}
	if accepted.LogicalConn == nil {
		t.Fatal("accepted logical connection should not be nil")
	}
	if accepted.TransportConn == nil {
		t.Fatal("accepted transport connection should not be nil")
	}
	clientHandle, ok := stream.(*streamHandle)
	if !ok {
		t.Fatalf("stream type = %T, want *streamHandle", stream)
	}
	if accepted.DataID == 0 {
		t.Fatal("accepted stream data id should not be zero")
	}
	// Both ends must agree on the data ID assigned to the stream.
	if got, want := clientHandle.dataIDSnapshot(), accepted.DataID; got != want {
		t.Fatalf("client stream data id = %d, want %d", got, want)
	}
	if _, err := stream.Write([]byte("hello-from-client")); err != nil {
		t.Fatalf("client stream Write failed: %v", err)
	}
	readStreamExactly(t, accepted.Stream, "hello-from-client", 2*time.Second)
	if _, err := accepted.Stream.Write([]byte("hello-from-server")); err != nil {
		t.Fatalf("server accepted stream Write failed: %v", err)
	}
	readStreamExactly(t, stream, "hello-from-server", 2*time.Second)
	if err := stream.Close(); err != nil {
		t.Fatalf("client stream Close failed: %v", err)
	}
	waitForStreamReadEOF(t, accepted.Stream, 2*time.Second)
	if err := accepted.Stream.Close(); err != nil {
		t.Fatalf("server accepted stream Close failed: %v", err)
	}
	waitForStreamContextDone(t, stream.Context(), 2*time.Second)
}
// TestStreamCloseWriteKeepsReadSideAliveTCP exercises half-close semantics.
// After the server calls CloseWrite, the client reads EOF but can still write
// to the server. After the client also calls CloseWrite, the server reads EOF
// and the client stream's context is done.
func TestStreamCloseWriteKeepsReadSideAliveTCP(t *testing.T) {
	const wait = 2 * time.Second

	srv := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(srv, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	incoming := make(chan StreamAcceptInfo, 1)
	srv.SetStreamHandler(func(info StreamAcceptInfo) error {
		incoming <- info
		return nil
	})
	if err := srv.Listen("tcp", "127.0.0.1:0"); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	defer func() { _ = srv.Stop() }()

	cli := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(cli, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	if err := cli.Connect("tcp", srv.listener.Addr().String()); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	defer func() { _ = cli.Stop() }()

	local, err := cli.OpenStream(context.Background(), StreamOpenOptions{Channel: StreamDataChannel})
	if err != nil {
		t.Fatalf("client OpenStream failed: %v", err)
	}
	remote := waitAcceptedStream(t, incoming, wait).Stream

	// Server half-closes: the client sees EOF, but its own write side stays usable.
	if err := remote.CloseWrite(); err != nil {
		t.Fatalf("server accepted stream CloseWrite failed: %v", err)
	}
	waitForStreamReadEOF(t, local, wait)
	const payload = "client-after-peer-close"
	if _, err := local.Write([]byte(payload)); err != nil {
		t.Fatalf("client stream Write after peer CloseWrite failed: %v", err)
	}
	readStreamExactly(t, remote, payload, wait)

	// Client half-closes too: the server sees EOF and the client context finishes.
	if err := local.CloseWrite(); err != nil {
		t.Fatalf("client stream CloseWrite failed: %v", err)
	}
	waitForStreamReadEOF(t, remote, wait)
	waitForStreamContextDone(t, local.Context(), wait)
}
// TestStreamCloseFullStopsPeerWritesTCP checks that a full Close on the
// server-accepted stream has three effects on the client stream: it reads
// EOF, its context completes, and further writes fail with io.ErrClosedPipe.
func TestStreamCloseFullStopsPeerWritesTCP(t *testing.T) {
	const wait = 2 * time.Second

	srv := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(srv, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	incoming := make(chan StreamAcceptInfo, 1)
	srv.SetStreamHandler(func(info StreamAcceptInfo) error {
		incoming <- info
		return nil
	})
	if err := srv.Listen("tcp", "127.0.0.1:0"); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	defer func() { _ = srv.Stop() }()

	cli := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(cli, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	if err := cli.Connect("tcp", srv.listener.Addr().String()); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	defer func() { _ = cli.Stop() }()

	local, err := cli.OpenStream(context.Background(), StreamOpenOptions{Channel: StreamDataChannel})
	if err != nil {
		t.Fatalf("client OpenStream failed: %v", err)
	}
	remote := waitAcceptedStream(t, incoming, wait).Stream
	if err := remote.Close(); err != nil {
		t.Fatalf("server accepted stream Close failed: %v", err)
	}
	waitForStreamReadEOF(t, local, wait)
	waitForStreamContextDone(t, local.Context(), wait)

	_, werr := local.Write([]byte("client-after-peer-full-close"))
	if !errors.Is(werr, io.ErrClosedPipe) {
		t.Fatalf("client stream Write after peer Close = %v, want %v", werr, io.ErrClosedPipe)
	}
}
// TestStreamCloseAfterCloseWriteStopsPeerWritesTCP runs in two phases.
// First, the client half-closes; the server reads EOF but can still reply.
// Then the client fully closes; the server reads EOF, its stream context
// completes, and its writes fail with io.ErrClosedPipe.
func TestStreamCloseAfterCloseWriteStopsPeerWritesTCP(t *testing.T) {
	const wait = 2 * time.Second

	srv := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(srv, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	incoming := make(chan StreamAcceptInfo, 1)
	srv.SetStreamHandler(func(info StreamAcceptInfo) error {
		incoming <- info
		return nil
	})
	if err := srv.Listen("tcp", "127.0.0.1:0"); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	defer func() { _ = srv.Stop() }()

	cli := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(cli, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	if err := cli.Connect("tcp", srv.listener.Addr().String()); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	defer func() { _ = cli.Stop() }()

	local, err := cli.OpenStream(context.Background(), StreamOpenOptions{Channel: StreamDataChannel})
	if err != nil {
		t.Fatalf("client OpenStream failed: %v", err)
	}
	remote := waitAcceptedStream(t, incoming, wait).Stream

	// Phase 1: client half-close.
	if err := local.CloseWrite(); err != nil {
		t.Fatalf("client stream CloseWrite failed: %v", err)
	}
	waitForStreamReadEOF(t, remote, wait)
	const reply = "server-can-still-reply"
	if _, err := remote.Write([]byte(reply)); err != nil {
		t.Fatalf("server accepted stream Write after peer CloseWrite failed: %v", err)
	}
	readStreamExactly(t, local, reply, wait)

	// Phase 2: client full close.
	if err := local.Close(); err != nil {
		t.Fatalf("client stream Close after CloseWrite failed: %v", err)
	}
	waitForStreamReadEOF(t, remote, wait)
	waitForStreamContextDone(t, remote.Context(), wait)
	if _, werr := remote.Write([]byte("server-after-peer-full-close")); !errors.Is(werr, io.ErrClosedPipe) {
		t.Fatalf("server accepted stream Write after peer Close = %v, want %v", werr, io.ErrClosedPipe)
	}
}
func TestStreamWritePrefersResetErrorOverContextCanceled(t *testing.T) {
wantErr := errors.New("remote stream reset")
runtime := newStreamRuntime("stream-reset")
stream := newStreamHandle(context.Background(), runtime, "test", StreamOpenRequest{
StreamID: "stream-reset-propagation",
DataID: 1,
}, 0, nil, nil, 0, nil, nil, func(ctx context.Context, s *streamHandle, chunk []byte) error {
s.markReset(wantErr)
<-ctx.Done()
return ctx.Err()
}, streamConfig{ChunkSize: 4})
_, err := stream.Write([]byte("abcdefgh"))
if !errors.Is(err, wantErr) {
t.Fatalf("stream Write error = %v, want %v", err, wantErr)
}
}
// TestStreamWriteWaitingBudgetPrefersClosedPipeOverContextCanceled covers a
// Write that is waiting for outbound budget when the peer closes. The test
// expects that Write to fail with io.ErrClosedPipe.
func TestStreamWriteWaitingBudgetPrefersClosedPipeOverContextCanceled(t *testing.T) {
	cfg := streamConfig{ChunkSize: 4, OutboundWindowBytes: 4, OutboundMaxInFlightChunks: 1}
	rt := newStreamRuntime("stream-budget-close")
	rt.applyConfig(cfg)
	// Take all 4 bytes of the configured outbound window before the Write starts.
	release, err := rt.acquireOutbound(context.Background(), 4)
	if err != nil {
		t.Fatalf("acquireOutbound setup failed: %v", err)
	}
	defer release()

	noopWriter := func(context.Context, *streamHandle, []byte) error { return nil }
	handle := newStreamHandle(context.Background(), rt, "test", StreamOpenRequest{
		StreamID: "stream-budget-close",
		DataID:   1,
	}, 0, nil, nil, 0, nil, nil, noopWriter, cfg)

	result := make(chan error, 1)
	go func() {
		_, werr := handle.Write([]byte("abcd"))
		result <- werr
	}()
	// Give the writer a moment to start waiting before the peer closes.
	time.Sleep(20 * time.Millisecond)
	handle.markPeerClosed()

	select {
	case <-time.After(time.Second):
		t.Fatal("stream Write did not return after peer close")
	case werr := <-result:
		if !errors.Is(werr, io.ErrClosedPipe) {
			t.Fatalf("stream Write error = %v, want %v", werr, io.ErrClosedPipe)
		}
	}
}
func TestStreamReadWaitingLocalClosePrefersClosedPipeOverContextCanceled(t *testing.T) {
stream := newStreamHandle(context.Background(), nil, "test", StreamOpenRequest{
StreamID: "stream-read-local-close",
DataID: 1,
}, 0, nil, nil, 0, nil, nil, nil, streamConfig{})
errCh := make(chan error, 1)
go func() {
buf := make([]byte, 4)
_, err := stream.Read(buf)
errCh <- err
}()
time.Sleep(20 * time.Millisecond)
if err := stream.Close(); err != nil {
t.Fatalf("stream Close failed: %v", err)
}
select {
case err := <-errCh:
if !errors.Is(err, io.ErrClosedPipe) {
t.Fatalf("stream Read error = %v, want %v", err, io.ErrClosedPipe)
}
case <-time.After(time.Second):
t.Fatal("stream Read did not return after local close")
}
}
// TestStreamOpenRoundTripServerToClientTCP covers a server-initiated stream.
// The server opens a control-channel stream on the logical connection
// returned by waitForTransferControlLogicalConn, and the test verifies that:
//   - the client's handler receives matching ID, channel, metadata and data ID,
//     with a nil LogicalConn on the client side;
//   - payloads flow both ways;
//   - closing the server stream delivers EOF to the client, and once the
//     client closes as well, the server stream's context is done.
func TestStreamOpenRoundTripServerToClientTCP(t *testing.T) {
	server := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	defer func() {
		_ = server.Stop()
	}()
	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	acceptCh := make(chan StreamAcceptInfo, 1)
	client.SetStreamHandler(func(info StreamAcceptInfo) error {
		acceptCh <- info
		return nil
	})
	if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	defer func() {
		_ = client.Stop()
	}()
	logical := waitForTransferControlLogicalConn(t, server, 2*time.Second)
	stream, err := server.OpenStreamLogical(context.Background(), logical, StreamOpenOptions{
		Channel: StreamControlChannel,
		Metadata: StreamMetadata{
			"purpose": "server-open",
		},
	})
	if err != nil {
		t.Fatalf("server OpenStreamLogical failed: %v", err)
	}
	// Use the shared helper (as the other TCP stream tests do) instead of a
	// hand-rolled select, so accept timeouts are handled consistently.
	accepted := waitAcceptedStream(t, acceptCh, 2*time.Second)
	if accepted.ID != stream.ID() {
		t.Fatalf("client accepted stream id mismatch: got %q want %q", accepted.ID, stream.ID())
	}
	if accepted.Channel != StreamControlChannel {
		t.Fatalf("client accepted stream channel mismatch: got %q want %q", accepted.Channel, StreamControlChannel)
	}
	if accepted.Metadata["purpose"] != "server-open" {
		t.Fatalf("client accepted metadata mismatch: %+v", accepted.Metadata)
	}
	if accepted.LogicalConn != nil {
		t.Fatalf("client accepted logical connection should be nil: %+v", accepted.LogicalConn)
	}
	serverHandle, ok := stream.(*streamHandle)
	if !ok {
		t.Fatalf("stream type = %T, want *streamHandle", stream)
	}
	if accepted.DataID == 0 {
		t.Fatal("client accepted stream data id should not be zero")
	}
	// Both ends must agree on the data ID assigned to the stream.
	if got, want := serverHandle.dataIDSnapshot(), accepted.DataID; got != want {
		t.Fatalf("server stream data id = %d, want %d", got, want)
	}
	if _, err := stream.Write([]byte("server-opened")); err != nil {
		t.Fatalf("server stream Write failed: %v", err)
	}
	readStreamExactly(t, accepted.Stream, "server-opened", 2*time.Second)
	if _, err := accepted.Stream.Write([]byte("client-accepted")); err != nil {
		t.Fatalf("client accepted stream Write failed: %v", err)
	}
	readStreamExactly(t, stream, "client-accepted", 2*time.Second)
	if err := stream.Close(); err != nil {
		t.Fatalf("server stream Close failed: %v", err)
	}
	waitForStreamReadEOF(t, accepted.Stream, 2*time.Second)
	if err := accepted.Stream.Close(); err != nil {
		t.Fatalf("client accepted stream Close failed: %v", err)
	}
	waitForStreamContextDone(t, stream.Context(), 2*time.Second)
}
// TestStreamResetRoundTripTCP checks that resetting the server-accepted stream
// with a cause has two effects on the client stream: its Read fails with an
// error whose text contains that cause, and its context is done.
func TestStreamResetRoundTripTCP(t *testing.T) {
	const wait = 2 * time.Second

	srv := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(srv, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	incoming := make(chan StreamAcceptInfo, 1)
	srv.SetStreamHandler(func(info StreamAcceptInfo) error {
		incoming <- info
		return nil
	})
	if err := srv.Listen("tcp", "127.0.0.1:0"); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	defer func() { _ = srv.Stop() }()

	cli := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(cli, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	if err := cli.Connect("tcp", srv.listener.Addr().String()); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	defer func() { _ = cli.Stop() }()

	local, err := cli.OpenStream(context.Background(), StreamOpenOptions{Channel: StreamDataChannel})
	if err != nil {
		t.Fatalf("client OpenStream failed: %v", err)
	}
	remote := waitAcceptedStream(t, incoming, wait).Stream

	cause := errors.New("stream-reset-by-server")
	if err := remote.Reset(cause); err != nil {
		t.Fatalf("server accepted stream Reset failed: %v", err)
	}
	// The cause crosses the TCP connection, so compare by message text, not identity.
	if readErr := readStreamError(t, local, wait); !strings.Contains(readErr.Error(), cause.Error()) {
		t.Fatalf("stream Read reset error mismatch: got %v want %q", readErr, cause.Error())
	}
	waitForStreamContextDone(t, local.Context(), wait)
}
// TestStreamSetReadDeadlineUnblocksPendingRead checks that setting a read
// deadline on a stream with a pending Read makes that Read fail with
// os.ErrDeadlineExceeded.
func TestStreamSetReadDeadlineUnblocksPendingRead(t *testing.T) {
	rt := newStreamRuntime("read-deadline")
	scope := clientFileScope()
	handle := newStreamHandle(context.Background(), rt, scope, StreamOpenRequest{
		StreamID:    "read-deadline-stream",
		Channel:     StreamDataChannel,
		ReadTimeout: time.Second,
	}, 0, nil, nil, 0, nil, nil, nil, rt.configSnapshot())
	if err := rt.register(scope, handle); err != nil {
		t.Fatalf("register stream failed: %v", err)
	}

	pending := make(chan error, 1)
	go func() {
		_, rerr := handle.Read(make([]byte, 1))
		pending <- rerr
	}()
	// Let the Read block first, then set a 40ms deadline (well under ReadTimeout).
	time.Sleep(20 * time.Millisecond)
	if err := handle.SetReadDeadline(time.Now().Add(40 * time.Millisecond)); err != nil {
		t.Fatalf("SetReadDeadline failed: %v", err)
	}

	select {
	case rerr := <-pending:
		if !errors.Is(rerr, os.ErrDeadlineExceeded) {
			t.Fatalf("stream Read error = %v, want %v", rerr, os.ErrDeadlineExceeded)
		}
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for read deadline")
	}
}
// TestStreamSetWriteDeadlineUnblocksBlockedWrite fills the outbound window,
// starts a Write, and then sets a write deadline. The blocked Write must fail
// with os.ErrDeadlineExceeded.
func TestStreamSetWriteDeadlineUnblocksBlockedWrite(t *testing.T) {
	rt := newStreamRuntime("write-deadline")
	rt.applyConfig(streamConfig{
		ChunkSize:                 4,
		InboundQueueLimit:         defaultStreamInboundQueueLimit,
		InboundBufferedBytesLimit: defaultStreamInboundBufferedBytesLimit,
		OutboundWindowBytes:       4,
		OutboundMaxInFlightChunks: 1,
	})
	// Reserve all 4 bytes of the outbound window before the Write below starts.
	reserveCtx, stopReserve := context.WithCancel(context.Background())
	defer stopReserve()
	release, err := rt.acquireOutbound(reserveCtx, 4)
	if err != nil {
		t.Fatalf("acquireOutbound failed: %v", err)
	}
	defer release()

	scope := clientFileScope()
	handle := newStreamHandle(context.Background(), rt, scope, StreamOpenRequest{
		StreamID:     "write-deadline-stream",
		Channel:      StreamDataChannel,
		WriteTimeout: time.Second,
		ReadTimeout:  time.Second,
	}, 0, nil, nil, 0, nil, nil, func(context.Context, *streamHandle, []byte) error {
		return nil
	}, rt.configSnapshot())
	if err := rt.register(scope, handle); err != nil {
		t.Fatalf("register stream failed: %v", err)
	}

	blocked := make(chan error, 1)
	go func() {
		_, werr := handle.Write([]byte("abcd"))
		blocked <- werr
	}()
	time.Sleep(20 * time.Millisecond)
	if err := handle.SetWriteDeadline(time.Now().Add(40 * time.Millisecond)); err != nil {
		t.Fatalf("SetWriteDeadline failed: %v", err)
	}

	select {
	case werr := <-blocked:
		if !errors.Is(werr, os.ErrDeadlineExceeded) {
			t.Fatalf("stream Write error = %v, want %v", werr, os.ErrDeadlineExceeded)
		}
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for write deadline")
	}
}
// TestStreamImplementsNetConnTCP drives both ends of a TCP stream only through
// the net.Conn interface. Both ends must report local and remote addresses,
// accept SetDeadline, and carry a payload in each direction.
func TestStreamImplementsNetConnTCP(t *testing.T) {
	const wait = 2 * time.Second

	srv := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(srv, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	incoming := make(chan StreamAcceptInfo, 1)
	srv.SetStreamHandler(func(info StreamAcceptInfo) error {
		incoming <- info
		return nil
	})
	if err := srv.Listen("tcp", "127.0.0.1:0"); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	defer func() { _ = srv.Stop() }()

	cli := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(cli, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	if err := cli.Connect("tcp", srv.listener.Addr().String()); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	defer func() { _ = cli.Stop() }()

	local, err := cli.OpenStream(context.Background(), StreamOpenOptions{Channel: StreamDataChannel})
	if err != nil {
		t.Fatalf("client OpenStream failed: %v", err)
	}
	remote := waitAcceptedStream(t, incoming, wait).Stream

	ends := []struct {
		label string
		conn  net.Conn
	}{
		{"client", local},
		{"server", remote},
	}
	// Address checks run for both ends before any deadline is set.
	for _, end := range ends {
		if end.conn.LocalAddr() == nil || end.conn.RemoteAddr() == nil {
			t.Fatalf("%s stream net.Conn addrs missing: local=%v remote=%v", end.label, end.conn.LocalAddr(), end.conn.RemoteAddr())
		}
	}
	for _, end := range ends {
		if err := end.conn.SetDeadline(time.Now().Add(time.Second)); err != nil {
			t.Fatalf("%s stream SetDeadline failed: %v", end.label, err)
		}
	}

	if _, err := ends[0].conn.Write([]byte("from-net-conn-client")); err != nil {
		t.Fatalf("client net.Conn Write failed: %v", err)
	}
	readStreamExactly(t, remote, "from-net-conn-client", wait)
	if _, err := ends[1].conn.Write([]byte("from-net-conn-server")); err != nil {
		t.Fatalf("server net.Conn Write failed: %v", err)
	}
	readStreamExactly(t, local, "from-net-conn-server", wait)
}
// TestClientDispatchStreamEnvelopeRejectsStaleSessionEpoch registers a client
// stream bound to an older session epoch and dispatches a data envelope to it.
// Reading the stream must fail with errTransportDetached, its context must be
// done, and the stream must be gone from the runtime.
func TestClientDispatchStreamEnvelopeRejectsStaleSessionEpoch(t *testing.T) {
	cli := NewClient().(*ClientCommon)
	rt := cli.getStreamRuntime()
	oldEpoch := cli.beginClientSessionEpoch()
	_ = cli.beginClientSessionEpoch() // start a newer epoch so oldEpoch is stale

	scope := clientFileScope()
	handle := newStreamHandle(context.Background(), rt, scope, StreamOpenRequest{
		StreamID:    "client-stale",
		Channel:     StreamDataChannel,
		ReadTimeout: 20 * time.Millisecond,
	}, oldEpoch, nil, nil, 0, nil, nil, nil, defaultStreamConfig())
	if err := rt.register(scope, handle); err != nil {
		t.Fatalf("register stale client stream failed: %v", err)
	}

	cli.dispatchStreamEnvelope(newStreamDataEnvelope("client-stale", []byte("payload")))
	if readErr := readStreamError(t, handle, time.Second); !errors.Is(readErr, errTransportDetached) {
		t.Fatalf("stale client stream read error mismatch: got %v want %v", readErr, errTransportDetached)
	}
	waitForStreamContextDone(t, handle.Context(), time.Second)
	if _, found := rt.lookup(scope, "client-stale"); found {
		t.Fatal("stale client stream should be removed from runtime")
	}
}
// TestServerDispatchStreamEnvelopeRejectsTransportGenerationMismatch binds a
// server stream to transport generation 1. A data envelope dispatched through
// a generation-2 transport (with a nil conn) is not delivered, so the read
// times out. An envelope through a generation-1 transport is then delivered.
func TestServerDispatchStreamEnvelopeRejectsTransportGenerationMismatch(t *testing.T) {
	srv := NewServer().(*ServerCommon)
	rt := srv.getStreamRuntime()
	logical := logicalConnFromClient(&ClientConn{
		ClientID: "server-stale-peer",
		server:   srv,
	})
	scope := serverFileScope(logical)

	// Builds a fresh attached generation-1 transport on each call.
	currentTransport := func() *TransportConn {
		return &TransportConn{
			logical:    logical,
			generation: 1,
			remoteAddr: streamTestAddr("current"),
			attached:   true,
		}
	}

	handle := newStreamHandle(context.Background(), rt, scope, StreamOpenRequest{
		StreamID:    "server-stale",
		Channel:     StreamDataChannel,
		ReadTimeout: 20 * time.Millisecond,
	}, 0, logical, currentTransport(), 1, nil, nil, nil, defaultStreamConfig())
	if err := rt.register(scope, handle); err != nil {
		t.Fatalf("register server stream failed: %v", err)
	}

	// Generation 2: payload must not reach the stream; the read times out.
	srv.dispatchStreamEnvelope(logical, &TransportConn{
		logical:    logical,
		generation: 2,
		remoteAddr: streamTestAddr("stale"),
		attached:   true,
	}, nil, newStreamDataEnvelope("server-stale", []byte("stale-payload")))
	if readErr := readStreamError(t, handle, time.Second); !errors.Is(readErr, os.ErrDeadlineExceeded) {
		t.Fatalf("server stale generation read error mismatch: got %v want %v", readErr, os.ErrDeadlineExceeded)
	}

	// Generation 1: payload is delivered.
	srv.dispatchStreamEnvelope(logical, currentTransport(), nil, newStreamDataEnvelope("server-stale", []byte("good-payload")))
	readStreamExactly(t, handle, "good-payload", time.Second)
}
// TestServerDispatchStreamEnvelopeRejectsTransportGenerationMismatchWritesResetViaInboundConn
// dispatches a stream data envelope together with the inbound conn (left end of a
// net.Pipe). The envelope's transport snapshot has a generation one past the
// generation the stream is bound to. The test then reads an envelope from the
// right end of the pipe and checks that it is:
//   - an EnvelopeSignal carrying an MSG_ASYNC transfer keyed StreamResetSignalKey;
//   - for stream "server-stale-reset";
//   - with an error string prefixed by errTransportDetached.
func TestServerDispatchStreamEnvelopeRejectsTransportGenerationMismatchWritesResetViaInboundConn(t *testing.T) {
	server := NewServer().(*ServerCommon)
	UseLegacySecurityServer(server)
	// Install a session runtime directly (cancelable stop context plus a stario
	// queue) instead of calling Listen; the test drives dispatch itself.
	runtimeCtx, runtimeCancel := context.WithCancel(context.Background())
	defer runtimeCancel()
	server.setServerSessionRuntime(&serverSessionRuntime{
		stopCtx: runtimeCtx,
		stopFn:  runtimeCancel,
		queue:   stario.NewQueueCtx(runtimeCtx, 4, math.MaxUint32),
	})
	// left is handed to the server as the accepted conn; right is the test's view of the peer.
	left, right := net.Pipe()
	defer left.Close()
	defer right.Close()
	logical := server.bootstrapAcceptedLogical("server-stream-reset", nil, left)
	if logical == nil {
		t.Fatal("bootstrapAcceptedLogical should return logical")
	}
	transport := logical.CurrentTransportConn()
	if transport == nil {
		t.Fatal("current transport should exist")
	}
	runtime := server.getStreamRuntime()
	// Bind the stream to the current transport generation.
	stream := newStreamHandle(context.Background(), runtime, serverFileScope(logical), StreamOpenRequest{
		StreamID: "server-stale-reset",
		Channel:  StreamDataChannel,
	}, 0, logical, transport, transport.TransportGeneration(), nil, nil, nil, defaultStreamConfig())
	if err := runtime.register(serverFileScope(logical), stream); err != nil {
		t.Fatalf("register server stream failed: %v", err)
	}
	// Snapshot of the same conn, but at generation+1, so it does not match the stream's binding.
	staleTransport := logical.transportConnSnapshotForInbound(left, nil, transport.TransportGeneration()+1, true)
	if staleTransport == nil {
		t.Fatal("stale transport snapshot should exist")
	}
	// Dispatch in a goroutine: the reply is written into the synchronous pipe and
	// only completes once the test reads it from right below.
	done := make(chan struct{})
	go func() {
		server.dispatchStreamEnvelope(logical, staleTransport, left, newStreamDataEnvelope("server-stale-reset", []byte("stale-payload")))
		close(done)
	}()
	env := readServerEnvelopeFromConn(t, server, logical, right, time.Second)
	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for stream dispatch to finish")
	}
	if env.Kind != EnvelopeSignal {
		t.Fatalf("reset envelope kind = %v, want %v", env.Kind, EnvelopeSignal)
	}
	transfer, err := unwrapTransferMsgEnvelope(env, server.sequenceDe)
	if err != nil {
		t.Fatalf("unwrapTransferMsgEnvelope failed: %v", err)
	}
	if transfer.Key != StreamResetSignalKey {
		t.Fatalf("reset transfer key = %q, want %q", transfer.Key, StreamResetSignalKey)
	}
	if transfer.Type != MSG_ASYNC {
		t.Fatalf("reset transfer type = %v, want %v", transfer.Type, MSG_ASYNC)
	}
	// Decode the reset payload and check it targets this stream with a detach error.
	var req StreamResetRequest
	if err := transfer.Value.Orm(&req); err != nil {
		t.Fatalf("decode reset request failed: %v", err)
	}
	if req.StreamID != "server-stale-reset" {
		t.Fatalf("reset stream id = %q, want %q", req.StreamID, "server-stale-reset")
	}
	if !strings.HasPrefix(req.Error, errTransportDetached.Error()) {
		t.Fatalf("reset error = %q, want prefix %q", req.Error, errTransportDetached.Error())
	}
}
// TestStreamBackpressureOverflowResetsStreamAndRemovesRuntimeEntry limits the
// inbound buffer to one chunk / 4 bytes and pushes two chunks. It checks that:
//   - the second push fails with errStreamBackpressureExceeded;
//   - reads then report the same error;
//   - the stream is removed from the runtime.
func TestStreamBackpressureOverflowResetsStreamAndRemovesRuntimeEntry(t *testing.T) {
	rt := newStreamRuntime("overflow")
	rt.applyConfig(streamConfig{
		ChunkSize:                 4,
		InboundQueueLimit:         1,
		InboundBufferedBytesLimit: 4,
	})
	scope := clientFileScope()
	handle := newStreamHandle(context.Background(), rt, scope, StreamOpenRequest{
		StreamID:    "overflow-stream",
		Channel:     StreamDataChannel,
		ReadTimeout: 20 * time.Millisecond,
	}, 0, nil, nil, 0, nil, nil, nil, rt.configSnapshot())
	if err := rt.register(scope, handle); err != nil {
		t.Fatalf("register stream failed: %v", err)
	}

	if err := handle.pushChunk([]byte("abcd")); err != nil {
		t.Fatalf("first pushChunk failed: %v", err)
	}
	overflowErr := handle.pushChunk([]byte("ef"))
	if !errors.Is(overflowErr, errStreamBackpressureExceeded) {
		t.Fatalf("overflow pushChunk error = %v, want %v", overflowErr, errStreamBackpressureExceeded)
	}
	if readErr := readStreamError(t, handle, time.Second); !errors.Is(readErr, errStreamBackpressureExceeded) {
		t.Fatalf("stream read error = %v, want %v", readErr, errStreamBackpressureExceeded)
	}
	if _, found := rt.lookup(scope, "overflow-stream"); found {
		t.Fatal("overflowed stream should be removed from runtime")
	}
}
// TestServerDetachLogicalSessionTransportResetsScopedStreams registers a
// stream in a logical connection's server scope, starts a session on a
// net.Pipe, and detaches the session transport. The stream's read must then
// fail with errTransportDetached, and the stream must be removed from the
// runtime.
func TestServerDetachLogicalSessionTransportResetsScopedStreams(t *testing.T) {
	srv := NewServer().(*ServerCommon)
	rt := srv.getStreamRuntime()
	logical := logicalConnFromClient(&ClientConn{
		ClientID: "detached-peer",
		server:   srv,
	})
	scope := serverFileScope(logical)
	boundTransport := &TransportConn{
		logical:    logical,
		generation: 1,
		remoteAddr: streamTestAddr("detach"),
		attached:   true,
	}
	handle := newStreamHandle(context.Background(), rt, scope, StreamOpenRequest{
		StreamID:    "detach-stream",
		Channel:     StreamDataChannel,
		ReadTimeout: 20 * time.Millisecond,
	}, 0, logical, boundTransport, 1, nil, nil, nil, defaultStreamConfig())
	if err := rt.register(scope, handle); err != nil {
		t.Fatalf("register stream failed: %v", err)
	}

	sessionEnd, peerEnd := net.Pipe()
	defer sessionEnd.Close()
	defer peerEnd.Close()
	logical.startSession(sessionEnd, nil, nil)
	srv.detachLogicalSessionTransport(logical, "read error", errors.New("boom"))

	if readErr := readStreamError(t, handle, time.Second); !errors.Is(readErr, errTransportDetached) {
		t.Fatalf("detached stream read error = %v, want %v", readErr, errTransportDetached)
	}
	if _, found := rt.lookup(scope, "detach-stream"); found {
		t.Fatal("detached stream should be removed from runtime")
	}
}
// TestGetStreamSnapshotsIncludesBufferedState registers a client control
// stream with metadata, timeouts and session epoch 7, then buffers one 5-byte
// chunk. It checks that GetClientStreamSnapshots reports all of those fields,
// plus open read sides and the "client-session" binding owner.
func TestGetStreamSnapshotsIncludesBufferedState(t *testing.T) {
	cli := NewClient().(*ClientCommon)
	rt := cli.getStreamRuntime()
	scope := clientFileScope()
	handle := newStreamHandle(context.Background(), rt, scope, StreamOpenRequest{
		StreamID:     "snapshot-stream",
		Channel:      StreamControlChannel,
		ReadTimeout:  time.Second,
		WriteTimeout: 2 * time.Second,
		Metadata:     StreamMetadata{"name": "snapshot-demo"},
	}, 7, nil, nil, 0, nil, nil, nil, defaultStreamConfig())
	handle.setClientSnapshotOwner(cli)
	if err := rt.register(scope, handle); err != nil {
		t.Fatalf("register snapshot stream failed: %v", err)
	}
	if err := handle.pushChunk([]byte("hello")); err != nil {
		t.Fatalf("pushChunk failed: %v", err)
	}

	snapshots, err := GetClientStreamSnapshots(cli)
	if err != nil {
		t.Fatalf("GetClientStreamSnapshots failed: %v", err)
	}
	if n := len(snapshots); n != 1 {
		t.Fatalf("stream snapshot count = %d, want %d", n, 1)
	}

	// Cases are evaluated in order; Fatalf stops at the first mismatching field.
	switch s := snapshots[0]; {
	case s.ID != "snapshot-stream":
		t.Fatalf("snapshot ID = %q, want %q", s.ID, "snapshot-stream")
	case s.Scope != scope:
		t.Fatalf("snapshot Scope = %q, want %q", s.Scope, scope)
	case s.Channel != StreamControlChannel:
		t.Fatalf("snapshot Channel = %q, want %q", s.Channel, StreamControlChannel)
	case s.SessionEpoch != uint64(7):
		t.Fatalf("snapshot SessionEpoch = %d, want %d", s.SessionEpoch, uint64(7))
	case s.BufferedChunks != 1:
		t.Fatalf("snapshot BufferedChunks = %d, want %d", s.BufferedChunks, 1)
	case s.BufferedBytes != 5:
		t.Fatalf("snapshot BufferedBytes = %d, want %d", s.BufferedBytes, 5)
	case s.LocalReadClosed:
		t.Fatal("snapshot LocalReadClosed should be false")
	case s.PeerReadClosed:
		t.Fatal("snapshot PeerReadClosed should be false")
	case s.Metadata["name"] != "snapshot-demo":
		t.Fatalf("snapshot metadata mismatch: %+v", s.Metadata)
	case s.ReadTimeout != time.Second:
		t.Fatalf("snapshot ReadTimeout = %v, want %v", s.ReadTimeout, time.Second)
	case s.WriteTimeout != 2*time.Second:
		t.Fatalf("snapshot WriteTimeout = %v, want %v", s.WriteTimeout, 2*time.Second)
	case s.BindingOwner != "client-session":
		t.Fatalf("snapshot BindingOwner = %q, want %q", s.BindingOwner, "client-session")
	}
}
// TestGetStreamSnapshotsIncludesIOObservability performs one 2-byte Read and
// one 5-byte Write on a stream, then sets read and write deadlines. It checks
// that the runtime snapshot reports:
//   - the configured addresses;
//   - byte and call counters;
//   - non-zero timestamps;
//   - the exact deadlines.
func TestGetStreamSnapshotsIncludesIOObservability(t *testing.T) {
	rt := newStreamRuntime("snapshot-observe")
	scope := clientFileScope()
	handle := newStreamHandle(context.Background(), rt, scope, StreamOpenRequest{
		StreamID: "snapshot-observe-stream",
		Channel:  StreamDataChannel,
	}, 0, nil, nil, 0, nil, nil, func(context.Context, *streamHandle, []byte) error {
		return nil
	}, rt.configSnapshot())
	handle.setAddrSnapshot(streamTestAddr("local-addr"), streamTestAddr("remote-addr"))
	if err := rt.register(scope, handle); err != nil {
		t.Fatalf("register stream failed: %v", err)
	}

	if err := handle.pushChunk([]byte("hello")); err != nil {
		t.Fatalf("pushChunk failed: %v", err)
	}
	if _, err := handle.Read(make([]byte, 2)); err != nil {
		t.Fatalf("stream Read failed: %v", err)
	}
	if _, err := handle.Write([]byte("world")); err != nil {
		t.Fatalf("stream Write failed: %v", err)
	}

	// Round(0) strips the monotonic reading so Equal compares wall-clock instants only.
	readDeadline := time.Now().Add(time.Minute).Round(0)
	writeDeadline := time.Now().Add(2 * time.Minute).Round(0)
	if err := handle.SetReadDeadline(readDeadline); err != nil {
		t.Fatalf("SetReadDeadline failed: %v", err)
	}
	if err := handle.SetWriteDeadline(writeDeadline); err != nil {
		t.Fatalf("SetWriteDeadline failed: %v", err)
	}

	snapshots := rt.snapshots()
	if n := len(snapshots); n != 1 {
		t.Fatalf("snapshot count = %d, want %d", n, 1)
	}

	// Cases are evaluated in order; Fatal/Fatalf stops at the first failing check.
	switch s := snapshots[0]; {
	case s.LocalAddress != "local-addr":
		t.Fatalf("snapshot LocalAddress = %q, want %q", s.LocalAddress, "local-addr")
	case s.RemoteAddress != "remote-addr":
		t.Fatalf("snapshot RemoteAddress = %q, want %q", s.RemoteAddress, "remote-addr")
	case s.BytesRead != int64(2):
		t.Fatalf("snapshot BytesRead = %d, want %d", s.BytesRead, int64(2))
	case s.BytesWritten != int64(5):
		t.Fatalf("snapshot BytesWritten = %d, want %d", s.BytesWritten, int64(5))
	case s.ReadCalls != int64(1):
		t.Fatalf("snapshot ReadCalls = %d, want %d", s.ReadCalls, int64(1))
	case s.WriteCalls != int64(1):
		t.Fatalf("snapshot WriteCalls = %d, want %d", s.WriteCalls, int64(1))
	case s.OpenedAt.IsZero():
		t.Fatal("snapshot OpenedAt should not be zero")
	case s.LastReadAt.IsZero():
		t.Fatal("snapshot LastReadAt should not be zero")
	case s.LastWriteAt.IsZero():
		t.Fatal("snapshot LastWriteAt should not be zero")
	case !s.ReadDeadline.Equal(readDeadline):
		t.Fatalf("snapshot ReadDeadline = %v, want %v", s.ReadDeadline, readDeadline)
	case !s.WriteDeadline.Equal(writeDeadline):
		t.Fatalf("snapshot WriteDeadline = %v, want %v", s.WriteDeadline, writeDeadline)
	}
}
// TestStreamSnapshotIncludesDetachedBindingDiagnostics checks that a stream
// snapshot taken after the server detaches the stream's transport reports the
// binding as no longer current/attached and exposes the detach reason, kind
// and error text.
func TestStreamSnapshotIncludesDetachedBindingDiagnostics(t *testing.T) {
	server := NewServer().(*ServerCommon)
	// In-memory connection pair: left is handed to the server-side logical
	// connection below; right is the peer end and is closed when the test
	// returns. NOTE(review): left is not closed here — presumably the
	// logical/detach path owns it; confirm.
	left, right := net.Pipe()
	defer right.Close()
	logical := server.bootstrapAcceptedLogical("stream-snapshot-detach", nil, left)
	if logical == nil {
		t.Fatal("bootstrapAcceptedLogical should return logical")
	}
	transport := logical.CurrentTransportConn()
	if transport == nil {
		t.Fatal("CurrentTransportConn should return active transport")
	}
	// Bind a stream handle to the transport that is current right now,
	// including its generation, so that the detach below leaves this binding
	// pointing at a transport that is no longer current.
	stream := newStreamHandle(context.Background(), newStreamRuntime("snapshot-detach"), serverFileScope(logical), StreamOpenRequest{
		StreamID: "stream-snapshot-detach",
	}, 0, logical, transport, transport.TransportGeneration(), nil, nil, nil, defaultStreamConfig())
	// Detach the session's transport with reason "read error" and error
	// "boom"; the snapshot must surface both.
	server.detachLogicalSessionTransport(logical, "read error", errors.New("boom"))
	snapshot := stream.snapshot()
	if got, want := snapshot.BindingOwner, "server-transport"; got != want {
		t.Fatalf("snapshot BindingOwner = %q, want %q", got, want)
	}
	if snapshot.BindingCurrent {
		t.Fatalf("snapshot BindingCurrent should be false after detach: %+v", snapshot)
	}
	if snapshot.TransportAttached {
		t.Fatalf("snapshot TransportAttached should be false after detach: %+v", snapshot)
	}
	if snapshot.TransportCurrent {
		t.Fatalf("snapshot TransportCurrent should be false after detach: %+v", snapshot)
	}
	if got, want := snapshot.TransportDetachReason, "read error"; got != want {
		t.Fatalf("snapshot TransportDetachReason = %q, want %q", got, want)
	}
	// The expected kind presumably derives from the "read error" reason
	// passed to detachLogicalSessionTransport — confirm against the mapping.
	if got, want := snapshot.TransportDetachKind, clientConnTransportDetachKindReadError; got != want {
		t.Fatalf("snapshot TransportDetachKind = %q, want %q", got, want)
	}
	if got, want := snapshot.TransportDetachError, "boom"; got != want {
		t.Fatalf("snapshot TransportDetachError = %q, want %q", got, want)
	}
}
// waitForStreamReadEOF polls stream until a Read reports io.EOF, failing the
// test if that does not happen within timeout.
//
// Reads that return data (nil error) or errStreamDataPathNotReady are treated
// as "not yet" and retried after a short pause; any other error fails the
// test immediately. Each Read runs in its own goroutine so a Read that blocks
// (for example, waiting for data that never arrives) cannot stall the test
// beyond timeout.
func waitForStreamReadEOF(t *testing.T, stream Stream, timeout time.Duration) {
	t.Helper()
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		// Buffered so the reader goroutine never blocks on send, even if this
		// helper has already given up waiting for it.
		errCh := make(chan error, 1)
		go func() {
			buf := make([]byte, 1)
			_, err := stream.Read(buf)
			errCh <- err
		}()
		var err error
		select {
		case err = <-errCh:
		case <-time.After(time.Until(deadline)):
			t.Fatal("timed out waiting for stream EOF")
		}
		if errors.Is(err, io.EOF) {
			return
		}
		if err != nil && !errors.Is(err, errStreamDataPathNotReady) {
			t.Fatalf("stream Read returned unexpected error: %v", err)
		}
		time.Sleep(10 * time.Millisecond)
	}
	t.Fatal("timed out waiting for stream EOF")
}
// waitForStreamContextDone blocks until ctx is done, failing the test if that
// takes longer than timeout.
func waitForStreamContextDone(t *testing.T, ctx context.Context, timeout time.Duration) {
	t.Helper()
	timer := time.NewTimer(timeout)
	defer timer.Stop()
	select {
	case <-timer.C:
		t.Fatal("timed out waiting for stream context done")
	case <-ctx.Done():
	}
}
// waitAcceptedStream returns the next StreamAcceptInfo received from ch. If
// nothing arrives within timeout the test fails; the zero value returned on
// that path is never seen by the caller because t.Fatal stops the calling
// goroutine.
func waitAcceptedStream(t *testing.T, ch <-chan StreamAcceptInfo, timeout time.Duration) StreamAcceptInfo {
	t.Helper()
	timer := time.NewTimer(timeout)
	defer timer.Stop()
	var accepted StreamAcceptInfo
	select {
	case <-timer.C:
		t.Fatal("timed out waiting for accepted stream")
	case accepted = <-ch:
	}
	return accepted
}
// readStreamExactly reads exactly len(want) bytes from stream and fails the
// test unless they equal want.
//
// The read runs in a background goroutine so the helper can enforce timeout.
// The result is handed back to the test goroutine, which reports failures
// with quoted payloads (%q) and byte counts. This keeps empty, whitespace-only
// or binary payloads unambiguous in the failure output.
func readStreamExactly(t *testing.T, stream Stream, want string, timeout time.Duration) {
	t.Helper()
	type readResult struct {
		payload []byte
		err     error
	}
	// Buffered so the reader goroutine never blocks if we time out first.
	resultCh := make(chan readResult, 1)
	go func() {
		payload := make([]byte, len(want))
		n, err := io.ReadFull(stream, payload)
		resultCh <- readResult{payload: payload[:n], err: err}
	}()
	select {
	case res := <-resultCh:
		if res.err != nil {
			t.Fatalf("stream read failed after %d of %d bytes (partial %q): %v", len(res.payload), len(want), res.payload, res.err)
		}
		if got := string(res.payload); got != want {
			t.Fatalf("stream payload mismatch: got %q want %q", got, want)
		}
	case <-time.After(timeout):
		t.Fatal("timed out waiting for stream payload")
	}
}
// readStreamError performs a single one-byte Read on stream in a background
// goroutine and returns the error it produced. The test fails if that Read
// succeeds (nil error) or does not return within timeout.
func readStreamError(t *testing.T, stream Stream, timeout time.Duration) error {
	t.Helper()
	// Buffered so the reader goroutine can always deliver, even after a timeout.
	result := make(chan error, 1)
	go func() {
		_, readErr := stream.Read(make([]byte, 1))
		result <- readErr
	}()
	timer := time.NewTimer(timeout)
	defer timer.Stop()
	select {
	case <-timer.C:
		t.Fatal("timed out waiting for stream read error")
		return nil
	case readErr := <-result:
		if readErr == nil {
			t.Fatal("expected stream read error, got nil")
		}
		return readErr
	}
}
// readServerEnvelopeFromConn reads raw bytes from conn, feeds them to a fresh
// stario queue, and decodes the first message the queue yields as an Envelope
// for logical using server.decodeEnvelopeLogical.
//
// The test fails if:
//   - no message is produced within timeout;
//   - ParseMessage or decoding fails;
//   - conn returns a read error other than a timeout.
//
// The read deadline installed on conn is cleared before returning, so later
// reads by the caller are not cut short by a stale deadline.
func readServerEnvelopeFromConn(t *testing.T, server *ServerCommon, logical *LogicalConn, conn net.Conn, timeout time.Duration) Envelope {
	t.Helper()
	queue := stario.NewQueue()
	deadline := time.Now().Add(timeout)
	// The deadline is fixed for the whole call, so install it once.
	if err := conn.SetReadDeadline(deadline); err != nil {
		t.Fatalf("SetReadDeadline failed: %v", err)
	}
	defer func() {
		// Best effort: conn may already be closed by the time we return.
		_ = conn.SetReadDeadline(time.Time{})
	}()
	buf := make([]byte, 4096)
	for time.Now().Before(deadline) {
		n, err := conn.Read(buf)
		if n > 0 {
			if parseErr := queue.ParseMessage(buf[:n], "stream-test"); parseErr != nil {
				t.Fatalf("ParseMessage failed: %v", parseErr)
			}
			// Non-blocking receive: if no complete message is available yet
			// (presumably a partial frame), keep reading.
			select {
			case msg := <-queue.RestoreChan():
				env, decErr := server.decodeEnvelopeLogical(logical, msg.Msg)
				if decErr != nil {
					t.Fatalf("decodeEnvelopeLogical failed: %v", decErr)
				}
				return env
			default:
			}
		}
		if err == nil {
			continue
		}
		// errors.As also matches net.Error values wrapped by the transport.
		var netErr net.Error
		if errors.Is(err, os.ErrDeadlineExceeded) || (errors.As(err, &netErr) && netErr.Timeout()) {
			break
		}
		t.Fatalf("conn Read failed: %v", err)
	}
	t.Fatal("timed out waiting for server envelope")
	return Envelope{}
}
// streamTestAddr is a string-backed net.Addr for tests: its network is always
// "stream-test" and its textual form is the string itself.
type streamTestAddr string

// Network reports the fixed synthetic network name "stream-test".
func (addr streamTestAddr) Network() string { return "stream-test" }

// String returns the underlying address text unchanged.
func (addr streamTestAddr) String() string { return string(addr) }

// Compile-time check that the value type satisfies net.Addr.
var _ net.Addr = streamTestAddr("")