1180 lines
38 KiB
Go
1180 lines
38 KiB
Go
|
|
package notify
|
||
|
|
|
||
|
|
import (
|
||
|
|
"b612.me/stario"
|
||
|
|
"context"
|
||
|
|
"errors"
|
||
|
|
"io"
|
||
|
|
"math"
|
||
|
|
"net"
|
||
|
|
"os"
|
||
|
|
"strings"
|
||
|
|
"testing"
|
||
|
|
"time"
|
||
|
|
)
|
||
|
|
|
||
|
|
func TestStreamOpenRoundTripTCP(t *testing.T) {
|
||
|
|
server := NewServer().(*ServerCommon)
|
||
|
|
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||
|
|
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
acceptCh := make(chan StreamAcceptInfo, 1)
|
||
|
|
server.SetStreamHandler(func(info StreamAcceptInfo) error {
|
||
|
|
acceptCh <- info
|
||
|
|
return nil
|
||
|
|
})
|
||
|
|
|
||
|
|
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
|
||
|
|
t.Fatalf("server Listen failed: %v", err)
|
||
|
|
}
|
||
|
|
defer func() {
|
||
|
|
_ = server.Stop()
|
||
|
|
}()
|
||
|
|
|
||
|
|
client := NewClient().(*ClientCommon)
|
||
|
|
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||
|
|
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||
|
|
}
|
||
|
|
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
|
||
|
|
t.Fatalf("client Connect failed: %v", err)
|
||
|
|
}
|
||
|
|
defer func() {
|
||
|
|
_ = client.Stop()
|
||
|
|
}()
|
||
|
|
|
||
|
|
stream, err := client.OpenStream(context.Background(), StreamOpenOptions{
|
||
|
|
Channel: StreamDataChannel,
|
||
|
|
Metadata: StreamMetadata{
|
||
|
|
"name": "demo.bin",
|
||
|
|
},
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("client OpenStream failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
var accepted StreamAcceptInfo
|
||
|
|
select {
|
||
|
|
case accepted = <-acceptCh:
|
||
|
|
case <-time.After(2 * time.Second):
|
||
|
|
t.Fatal("timed out waiting for accepted stream")
|
||
|
|
}
|
||
|
|
|
||
|
|
if accepted.ID != stream.ID() {
|
||
|
|
t.Fatalf("accepted stream id mismatch: got %q want %q", accepted.ID, stream.ID())
|
||
|
|
}
|
||
|
|
if accepted.Channel != StreamDataChannel {
|
||
|
|
t.Fatalf("accepted stream channel mismatch: got %q want %q", accepted.Channel, StreamDataChannel)
|
||
|
|
}
|
||
|
|
if accepted.Metadata["name"] != "demo.bin" {
|
||
|
|
t.Fatalf("accepted metadata mismatch: %+v", accepted.Metadata)
|
||
|
|
}
|
||
|
|
if accepted.LogicalConn == nil {
|
||
|
|
t.Fatal("accepted logical connection should not be nil")
|
||
|
|
}
|
||
|
|
if accepted.TransportConn == nil {
|
||
|
|
t.Fatal("accepted transport connection should not be nil")
|
||
|
|
}
|
||
|
|
clientHandle, ok := stream.(*streamHandle)
|
||
|
|
if !ok {
|
||
|
|
t.Fatalf("stream type = %T, want *streamHandle", stream)
|
||
|
|
}
|
||
|
|
if accepted.DataID == 0 {
|
||
|
|
t.Fatal("accepted stream data id should not be zero")
|
||
|
|
}
|
||
|
|
if got, want := clientHandle.dataIDSnapshot(), accepted.DataID; got != want {
|
||
|
|
t.Fatalf("client stream data id = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
|
||
|
|
if _, err := stream.Write([]byte("hello-from-client")); err != nil {
|
||
|
|
t.Fatalf("client stream Write failed: %v", err)
|
||
|
|
}
|
||
|
|
readStreamExactly(t, accepted.Stream, "hello-from-client", 2*time.Second)
|
||
|
|
|
||
|
|
if _, err := accepted.Stream.Write([]byte("hello-from-server")); err != nil {
|
||
|
|
t.Fatalf("server accepted stream Write failed: %v", err)
|
||
|
|
}
|
||
|
|
readStreamExactly(t, stream, "hello-from-server", 2*time.Second)
|
||
|
|
|
||
|
|
if err := stream.Close(); err != nil {
|
||
|
|
t.Fatalf("client stream Close failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
waitForStreamReadEOF(t, accepted.Stream, 2*time.Second)
|
||
|
|
|
||
|
|
if err := accepted.Stream.Close(); err != nil {
|
||
|
|
t.Fatalf("server accepted stream Close failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
waitForStreamContextDone(t, stream.Context(), 2*time.Second)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestStreamCloseWriteKeepsReadSideAliveTCP(t *testing.T) {
|
||
|
|
server := NewServer().(*ServerCommon)
|
||
|
|
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||
|
|
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
acceptCh := make(chan StreamAcceptInfo, 1)
|
||
|
|
server.SetStreamHandler(func(info StreamAcceptInfo) error {
|
||
|
|
acceptCh <- info
|
||
|
|
return nil
|
||
|
|
})
|
||
|
|
|
||
|
|
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
|
||
|
|
t.Fatalf("server Listen failed: %v", err)
|
||
|
|
}
|
||
|
|
defer func() {
|
||
|
|
_ = server.Stop()
|
||
|
|
}()
|
||
|
|
|
||
|
|
client := NewClient().(*ClientCommon)
|
||
|
|
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||
|
|
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||
|
|
}
|
||
|
|
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
|
||
|
|
t.Fatalf("client Connect failed: %v", err)
|
||
|
|
}
|
||
|
|
defer func() {
|
||
|
|
_ = client.Stop()
|
||
|
|
}()
|
||
|
|
|
||
|
|
stream, err := client.OpenStream(context.Background(), StreamOpenOptions{Channel: StreamDataChannel})
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("client OpenStream failed: %v", err)
|
||
|
|
}
|
||
|
|
accepted := waitAcceptedStream(t, acceptCh, 2*time.Second)
|
||
|
|
|
||
|
|
if err := accepted.Stream.CloseWrite(); err != nil {
|
||
|
|
t.Fatalf("server accepted stream CloseWrite failed: %v", err)
|
||
|
|
}
|
||
|
|
waitForStreamReadEOF(t, stream, 2*time.Second)
|
||
|
|
|
||
|
|
if _, err := stream.Write([]byte("client-after-peer-close")); err != nil {
|
||
|
|
t.Fatalf("client stream Write after peer CloseWrite failed: %v", err)
|
||
|
|
}
|
||
|
|
readStreamExactly(t, accepted.Stream, "client-after-peer-close", 2*time.Second)
|
||
|
|
|
||
|
|
if err := stream.CloseWrite(); err != nil {
|
||
|
|
t.Fatalf("client stream CloseWrite failed: %v", err)
|
||
|
|
}
|
||
|
|
waitForStreamReadEOF(t, accepted.Stream, 2*time.Second)
|
||
|
|
waitForStreamContextDone(t, stream.Context(), 2*time.Second)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestStreamCloseFullStopsPeerWritesTCP(t *testing.T) {
|
||
|
|
server := NewServer().(*ServerCommon)
|
||
|
|
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||
|
|
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
acceptCh := make(chan StreamAcceptInfo, 1)
|
||
|
|
server.SetStreamHandler(func(info StreamAcceptInfo) error {
|
||
|
|
acceptCh <- info
|
||
|
|
return nil
|
||
|
|
})
|
||
|
|
|
||
|
|
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
|
||
|
|
t.Fatalf("server Listen failed: %v", err)
|
||
|
|
}
|
||
|
|
defer func() {
|
||
|
|
_ = server.Stop()
|
||
|
|
}()
|
||
|
|
|
||
|
|
client := NewClient().(*ClientCommon)
|
||
|
|
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||
|
|
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||
|
|
}
|
||
|
|
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
|
||
|
|
t.Fatalf("client Connect failed: %v", err)
|
||
|
|
}
|
||
|
|
defer func() {
|
||
|
|
_ = client.Stop()
|
||
|
|
}()
|
||
|
|
|
||
|
|
stream, err := client.OpenStream(context.Background(), StreamOpenOptions{Channel: StreamDataChannel})
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("client OpenStream failed: %v", err)
|
||
|
|
}
|
||
|
|
accepted := waitAcceptedStream(t, acceptCh, 2*time.Second)
|
||
|
|
|
||
|
|
if err := accepted.Stream.Close(); err != nil {
|
||
|
|
t.Fatalf("server accepted stream Close failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
waitForStreamReadEOF(t, stream, 2*time.Second)
|
||
|
|
waitForStreamContextDone(t, stream.Context(), 2*time.Second)
|
||
|
|
|
||
|
|
if _, err := stream.Write([]byte("client-after-peer-full-close")); !errors.Is(err, io.ErrClosedPipe) {
|
||
|
|
t.Fatalf("client stream Write after peer Close = %v, want %v", err, io.ErrClosedPipe)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestStreamCloseAfterCloseWriteStopsPeerWritesTCP(t *testing.T) {
|
||
|
|
server := NewServer().(*ServerCommon)
|
||
|
|
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||
|
|
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
acceptCh := make(chan StreamAcceptInfo, 1)
|
||
|
|
server.SetStreamHandler(func(info StreamAcceptInfo) error {
|
||
|
|
acceptCh <- info
|
||
|
|
return nil
|
||
|
|
})
|
||
|
|
|
||
|
|
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
|
||
|
|
t.Fatalf("server Listen failed: %v", err)
|
||
|
|
}
|
||
|
|
defer func() {
|
||
|
|
_ = server.Stop()
|
||
|
|
}()
|
||
|
|
|
||
|
|
client := NewClient().(*ClientCommon)
|
||
|
|
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||
|
|
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||
|
|
}
|
||
|
|
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
|
||
|
|
t.Fatalf("client Connect failed: %v", err)
|
||
|
|
}
|
||
|
|
defer func() {
|
||
|
|
_ = client.Stop()
|
||
|
|
}()
|
||
|
|
|
||
|
|
stream, err := client.OpenStream(context.Background(), StreamOpenOptions{Channel: StreamDataChannel})
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("client OpenStream failed: %v", err)
|
||
|
|
}
|
||
|
|
accepted := waitAcceptedStream(t, acceptCh, 2*time.Second)
|
||
|
|
|
||
|
|
if err := stream.CloseWrite(); err != nil {
|
||
|
|
t.Fatalf("client stream CloseWrite failed: %v", err)
|
||
|
|
}
|
||
|
|
waitForStreamReadEOF(t, accepted.Stream, 2*time.Second)
|
||
|
|
|
||
|
|
if _, err := accepted.Stream.Write([]byte("server-can-still-reply")); err != nil {
|
||
|
|
t.Fatalf("server accepted stream Write after peer CloseWrite failed: %v", err)
|
||
|
|
}
|
||
|
|
readStreamExactly(t, stream, "server-can-still-reply", 2*time.Second)
|
||
|
|
|
||
|
|
if err := stream.Close(); err != nil {
|
||
|
|
t.Fatalf("client stream Close after CloseWrite failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
waitForStreamReadEOF(t, accepted.Stream, 2*time.Second)
|
||
|
|
waitForStreamContextDone(t, accepted.Stream.Context(), 2*time.Second)
|
||
|
|
|
||
|
|
if _, err := accepted.Stream.Write([]byte("server-after-peer-full-close")); !errors.Is(err, io.ErrClosedPipe) {
|
||
|
|
t.Fatalf("server accepted stream Write after peer Close = %v, want %v", err, io.ErrClosedPipe)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestStreamWritePrefersResetErrorOverContextCanceled(t *testing.T) {
|
||
|
|
wantErr := errors.New("remote stream reset")
|
||
|
|
runtime := newStreamRuntime("stream-reset")
|
||
|
|
stream := newStreamHandle(context.Background(), runtime, "test", StreamOpenRequest{
|
||
|
|
StreamID: "stream-reset-propagation",
|
||
|
|
DataID: 1,
|
||
|
|
}, 0, nil, nil, 0, nil, nil, func(ctx context.Context, s *streamHandle, chunk []byte) error {
|
||
|
|
s.markReset(wantErr)
|
||
|
|
<-ctx.Done()
|
||
|
|
return ctx.Err()
|
||
|
|
}, streamConfig{ChunkSize: 4})
|
||
|
|
|
||
|
|
_, err := stream.Write([]byte("abcdefgh"))
|
||
|
|
if !errors.Is(err, wantErr) {
|
||
|
|
t.Fatalf("stream Write error = %v, want %v", err, wantErr)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestStreamWriteWaitingBudgetPrefersClosedPipeOverContextCanceled(t *testing.T) {
|
||
|
|
cfg := streamConfig{
|
||
|
|
ChunkSize: 4,
|
||
|
|
OutboundWindowBytes: 4,
|
||
|
|
OutboundMaxInFlightChunks: 1,
|
||
|
|
}
|
||
|
|
runtime := newStreamRuntime("stream-budget-close")
|
||
|
|
runtime.applyConfig(cfg)
|
||
|
|
release, err := runtime.acquireOutbound(context.Background(), 4)
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("acquireOutbound setup failed: %v", err)
|
||
|
|
}
|
||
|
|
defer release()
|
||
|
|
|
||
|
|
stream := newStreamHandle(context.Background(), runtime, "test", StreamOpenRequest{
|
||
|
|
StreamID: "stream-budget-close",
|
||
|
|
DataID: 1,
|
||
|
|
}, 0, nil, nil, 0, nil, nil, func(context.Context, *streamHandle, []byte) error {
|
||
|
|
return nil
|
||
|
|
}, cfg)
|
||
|
|
|
||
|
|
errCh := make(chan error, 1)
|
||
|
|
go func() {
|
||
|
|
_, err := stream.Write([]byte("abcd"))
|
||
|
|
errCh <- err
|
||
|
|
}()
|
||
|
|
|
||
|
|
time.Sleep(20 * time.Millisecond)
|
||
|
|
stream.markPeerClosed()
|
||
|
|
|
||
|
|
select {
|
||
|
|
case err := <-errCh:
|
||
|
|
if !errors.Is(err, io.ErrClosedPipe) {
|
||
|
|
t.Fatalf("stream Write error = %v, want %v", err, io.ErrClosedPipe)
|
||
|
|
}
|
||
|
|
case <-time.After(time.Second):
|
||
|
|
t.Fatal("stream Write did not return after peer close")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestStreamReadWaitingLocalClosePrefersClosedPipeOverContextCanceled(t *testing.T) {
|
||
|
|
stream := newStreamHandle(context.Background(), nil, "test", StreamOpenRequest{
|
||
|
|
StreamID: "stream-read-local-close",
|
||
|
|
DataID: 1,
|
||
|
|
}, 0, nil, nil, 0, nil, nil, nil, streamConfig{})
|
||
|
|
|
||
|
|
errCh := make(chan error, 1)
|
||
|
|
go func() {
|
||
|
|
buf := make([]byte, 4)
|
||
|
|
_, err := stream.Read(buf)
|
||
|
|
errCh <- err
|
||
|
|
}()
|
||
|
|
|
||
|
|
time.Sleep(20 * time.Millisecond)
|
||
|
|
if err := stream.Close(); err != nil {
|
||
|
|
t.Fatalf("stream Close failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
select {
|
||
|
|
case err := <-errCh:
|
||
|
|
if !errors.Is(err, io.ErrClosedPipe) {
|
||
|
|
t.Fatalf("stream Read error = %v, want %v", err, io.ErrClosedPipe)
|
||
|
|
}
|
||
|
|
case <-time.After(time.Second):
|
||
|
|
t.Fatal("stream Read did not return after local close")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestStreamOpenRoundTripServerToClientTCP(t *testing.T) {
|
||
|
|
server := NewServer().(*ServerCommon)
|
||
|
|
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||
|
|
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||
|
|
}
|
||
|
|
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
|
||
|
|
t.Fatalf("server Listen failed: %v", err)
|
||
|
|
}
|
||
|
|
defer func() {
|
||
|
|
_ = server.Stop()
|
||
|
|
}()
|
||
|
|
|
||
|
|
client := NewClient().(*ClientCommon)
|
||
|
|
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||
|
|
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||
|
|
}
|
||
|
|
acceptCh := make(chan StreamAcceptInfo, 1)
|
||
|
|
client.SetStreamHandler(func(info StreamAcceptInfo) error {
|
||
|
|
acceptCh <- info
|
||
|
|
return nil
|
||
|
|
})
|
||
|
|
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
|
||
|
|
t.Fatalf("client Connect failed: %v", err)
|
||
|
|
}
|
||
|
|
defer func() {
|
||
|
|
_ = client.Stop()
|
||
|
|
}()
|
||
|
|
|
||
|
|
logical := waitForTransferControlLogicalConn(t, server, 2*time.Second)
|
||
|
|
stream, err := server.OpenStreamLogical(context.Background(), logical, StreamOpenOptions{
|
||
|
|
Channel: StreamControlChannel,
|
||
|
|
Metadata: StreamMetadata{
|
||
|
|
"purpose": "server-open",
|
||
|
|
},
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("server OpenStreamLogical failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
var accepted StreamAcceptInfo
|
||
|
|
select {
|
||
|
|
case accepted = <-acceptCh:
|
||
|
|
case <-time.After(2 * time.Second):
|
||
|
|
t.Fatal("timed out waiting for client accepted stream")
|
||
|
|
}
|
||
|
|
|
||
|
|
if accepted.ID != stream.ID() {
|
||
|
|
t.Fatalf("client accepted stream id mismatch: got %q want %q", accepted.ID, stream.ID())
|
||
|
|
}
|
||
|
|
if accepted.Channel != StreamControlChannel {
|
||
|
|
t.Fatalf("client accepted stream channel mismatch: got %q want %q", accepted.Channel, StreamControlChannel)
|
||
|
|
}
|
||
|
|
if accepted.Metadata["purpose"] != "server-open" {
|
||
|
|
t.Fatalf("client accepted metadata mismatch: %+v", accepted.Metadata)
|
||
|
|
}
|
||
|
|
if accepted.LogicalConn != nil {
|
||
|
|
t.Fatalf("client accepted logical connection should be nil: %+v", accepted.LogicalConn)
|
||
|
|
}
|
||
|
|
serverHandle, ok := stream.(*streamHandle)
|
||
|
|
if !ok {
|
||
|
|
t.Fatalf("stream type = %T, want *streamHandle", stream)
|
||
|
|
}
|
||
|
|
if accepted.DataID == 0 {
|
||
|
|
t.Fatal("client accepted stream data id should not be zero")
|
||
|
|
}
|
||
|
|
if got, want := serverHandle.dataIDSnapshot(), accepted.DataID; got != want {
|
||
|
|
t.Fatalf("server stream data id = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
|
||
|
|
if _, err := stream.Write([]byte("server-opened")); err != nil {
|
||
|
|
t.Fatalf("server stream Write failed: %v", err)
|
||
|
|
}
|
||
|
|
readStreamExactly(t, accepted.Stream, "server-opened", 2*time.Second)
|
||
|
|
|
||
|
|
if _, err := accepted.Stream.Write([]byte("client-accepted")); err != nil {
|
||
|
|
t.Fatalf("client accepted stream Write failed: %v", err)
|
||
|
|
}
|
||
|
|
readStreamExactly(t, stream, "client-accepted", 2*time.Second)
|
||
|
|
|
||
|
|
if err := stream.Close(); err != nil {
|
||
|
|
t.Fatalf("server stream Close failed: %v", err)
|
||
|
|
}
|
||
|
|
waitForStreamReadEOF(t, accepted.Stream, 2*time.Second)
|
||
|
|
|
||
|
|
if err := accepted.Stream.Close(); err != nil {
|
||
|
|
t.Fatalf("client accepted stream Close failed: %v", err)
|
||
|
|
}
|
||
|
|
waitForStreamContextDone(t, stream.Context(), 2*time.Second)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestStreamResetRoundTripTCP(t *testing.T) {
|
||
|
|
server := NewServer().(*ServerCommon)
|
||
|
|
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||
|
|
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||
|
|
}
|
||
|
|
acceptCh := make(chan StreamAcceptInfo, 1)
|
||
|
|
server.SetStreamHandler(func(info StreamAcceptInfo) error {
|
||
|
|
acceptCh <- info
|
||
|
|
return nil
|
||
|
|
})
|
||
|
|
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
|
||
|
|
t.Fatalf("server Listen failed: %v", err)
|
||
|
|
}
|
||
|
|
defer func() {
|
||
|
|
_ = server.Stop()
|
||
|
|
}()
|
||
|
|
|
||
|
|
client := NewClient().(*ClientCommon)
|
||
|
|
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||
|
|
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||
|
|
}
|
||
|
|
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
|
||
|
|
t.Fatalf("client Connect failed: %v", err)
|
||
|
|
}
|
||
|
|
defer func() {
|
||
|
|
_ = client.Stop()
|
||
|
|
}()
|
||
|
|
|
||
|
|
stream, err := client.OpenStream(context.Background(), StreamOpenOptions{Channel: StreamDataChannel})
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("client OpenStream failed: %v", err)
|
||
|
|
}
|
||
|
|
accepted := waitAcceptedStream(t, acceptCh, 2*time.Second)
|
||
|
|
|
||
|
|
resetCause := errors.New("stream-reset-by-server")
|
||
|
|
if err := accepted.Stream.Reset(resetCause); err != nil {
|
||
|
|
t.Fatalf("server accepted stream Reset failed: %v", err)
|
||
|
|
}
|
||
|
|
readErr := readStreamError(t, stream, 2*time.Second)
|
||
|
|
if !strings.Contains(readErr.Error(), resetCause.Error()) {
|
||
|
|
t.Fatalf("stream Read reset error mismatch: got %v want %q", readErr, resetCause.Error())
|
||
|
|
}
|
||
|
|
waitForStreamContextDone(t, stream.Context(), 2*time.Second)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestStreamSetReadDeadlineUnblocksPendingRead(t *testing.T) {
|
||
|
|
runtime := newStreamRuntime("read-deadline")
|
||
|
|
stream := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{
|
||
|
|
StreamID: "read-deadline-stream",
|
||
|
|
Channel: StreamDataChannel,
|
||
|
|
ReadTimeout: time.Second,
|
||
|
|
}, 0, nil, nil, 0, nil, nil, nil, runtime.configSnapshot())
|
||
|
|
if err := runtime.register(clientFileScope(), stream); err != nil {
|
||
|
|
t.Fatalf("register stream failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
errCh := make(chan error, 1)
|
||
|
|
go func() {
|
||
|
|
buf := make([]byte, 1)
|
||
|
|
_, err := stream.Read(buf)
|
||
|
|
errCh <- err
|
||
|
|
}()
|
||
|
|
|
||
|
|
time.Sleep(20 * time.Millisecond)
|
||
|
|
if err := stream.SetReadDeadline(time.Now().Add(40 * time.Millisecond)); err != nil {
|
||
|
|
t.Fatalf("SetReadDeadline failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
select {
|
||
|
|
case err := <-errCh:
|
||
|
|
if !errors.Is(err, os.ErrDeadlineExceeded) {
|
||
|
|
t.Fatalf("stream Read error = %v, want %v", err, os.ErrDeadlineExceeded)
|
||
|
|
}
|
||
|
|
case <-time.After(time.Second):
|
||
|
|
t.Fatal("timed out waiting for read deadline")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestStreamSetWriteDeadlineUnblocksBlockedWrite(t *testing.T) {
|
||
|
|
runtime := newStreamRuntime("write-deadline")
|
||
|
|
runtime.applyConfig(streamConfig{
|
||
|
|
ChunkSize: 4,
|
||
|
|
InboundQueueLimit: defaultStreamInboundQueueLimit,
|
||
|
|
InboundBufferedBytesLimit: defaultStreamInboundBufferedBytesLimit,
|
||
|
|
OutboundWindowBytes: 4,
|
||
|
|
OutboundMaxInFlightChunks: 1,
|
||
|
|
})
|
||
|
|
|
||
|
|
holdCtx, holdCancel := context.WithCancel(context.Background())
|
||
|
|
defer holdCancel()
|
||
|
|
release, err := runtime.acquireOutbound(holdCtx, 4)
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("acquireOutbound failed: %v", err)
|
||
|
|
}
|
||
|
|
defer release()
|
||
|
|
|
||
|
|
stream := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{
|
||
|
|
StreamID: "write-deadline-stream",
|
||
|
|
Channel: StreamDataChannel,
|
||
|
|
WriteTimeout: time.Second,
|
||
|
|
ReadTimeout: time.Second,
|
||
|
|
}, 0, nil, nil, 0, nil, nil, func(context.Context, *streamHandle, []byte) error {
|
||
|
|
return nil
|
||
|
|
}, runtime.configSnapshot())
|
||
|
|
if err := runtime.register(clientFileScope(), stream); err != nil {
|
||
|
|
t.Fatalf("register stream failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
errCh := make(chan error, 1)
|
||
|
|
go func() {
|
||
|
|
_, err := stream.Write([]byte("abcd"))
|
||
|
|
errCh <- err
|
||
|
|
}()
|
||
|
|
|
||
|
|
time.Sleep(20 * time.Millisecond)
|
||
|
|
if err := stream.SetWriteDeadline(time.Now().Add(40 * time.Millisecond)); err != nil {
|
||
|
|
t.Fatalf("SetWriteDeadline failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
select {
|
||
|
|
case err := <-errCh:
|
||
|
|
if !errors.Is(err, os.ErrDeadlineExceeded) {
|
||
|
|
t.Fatalf("stream Write error = %v, want %v", err, os.ErrDeadlineExceeded)
|
||
|
|
}
|
||
|
|
case <-time.After(time.Second):
|
||
|
|
t.Fatal("timed out waiting for write deadline")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestStreamImplementsNetConnTCP(t *testing.T) {
|
||
|
|
server := NewServer().(*ServerCommon)
|
||
|
|
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||
|
|
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
acceptCh := make(chan StreamAcceptInfo, 1)
|
||
|
|
server.SetStreamHandler(func(info StreamAcceptInfo) error {
|
||
|
|
acceptCh <- info
|
||
|
|
return nil
|
||
|
|
})
|
||
|
|
|
||
|
|
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
|
||
|
|
t.Fatalf("server Listen failed: %v", err)
|
||
|
|
}
|
||
|
|
defer func() {
|
||
|
|
_ = server.Stop()
|
||
|
|
}()
|
||
|
|
|
||
|
|
client := NewClient().(*ClientCommon)
|
||
|
|
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||
|
|
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||
|
|
}
|
||
|
|
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
|
||
|
|
t.Fatalf("client Connect failed: %v", err)
|
||
|
|
}
|
||
|
|
defer func() {
|
||
|
|
_ = client.Stop()
|
||
|
|
}()
|
||
|
|
|
||
|
|
stream, err := client.OpenStream(context.Background(), StreamOpenOptions{Channel: StreamDataChannel})
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("client OpenStream failed: %v", err)
|
||
|
|
}
|
||
|
|
accepted := waitAcceptedStream(t, acceptCh, 2*time.Second)
|
||
|
|
|
||
|
|
var clientConn net.Conn = stream
|
||
|
|
var serverConn net.Conn = accepted.Stream
|
||
|
|
|
||
|
|
if clientConn.LocalAddr() == nil || clientConn.RemoteAddr() == nil {
|
||
|
|
t.Fatalf("client stream net.Conn addrs missing: local=%v remote=%v", clientConn.LocalAddr(), clientConn.RemoteAddr())
|
||
|
|
}
|
||
|
|
if serverConn.LocalAddr() == nil || serverConn.RemoteAddr() == nil {
|
||
|
|
t.Fatalf("server stream net.Conn addrs missing: local=%v remote=%v", serverConn.LocalAddr(), serverConn.RemoteAddr())
|
||
|
|
}
|
||
|
|
if err := clientConn.SetDeadline(time.Now().Add(time.Second)); err != nil {
|
||
|
|
t.Fatalf("client stream SetDeadline failed: %v", err)
|
||
|
|
}
|
||
|
|
if err := serverConn.SetDeadline(time.Now().Add(time.Second)); err != nil {
|
||
|
|
t.Fatalf("server stream SetDeadline failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if _, err := clientConn.Write([]byte("from-net-conn-client")); err != nil {
|
||
|
|
t.Fatalf("client net.Conn Write failed: %v", err)
|
||
|
|
}
|
||
|
|
readStreamExactly(t, accepted.Stream, "from-net-conn-client", 2*time.Second)
|
||
|
|
|
||
|
|
if _, err := serverConn.Write([]byte("from-net-conn-server")); err != nil {
|
||
|
|
t.Fatalf("server net.Conn Write failed: %v", err)
|
||
|
|
}
|
||
|
|
readStreamExactly(t, stream, "from-net-conn-server", 2*time.Second)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestClientDispatchStreamEnvelopeRejectsStaleSessionEpoch(t *testing.T) {
|
||
|
|
client := NewClient().(*ClientCommon)
|
||
|
|
runtime := client.getStreamRuntime()
|
||
|
|
|
||
|
|
staleEpoch := client.beginClientSessionEpoch()
|
||
|
|
_ = client.beginClientSessionEpoch()
|
||
|
|
|
||
|
|
stream := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{
|
||
|
|
StreamID: "client-stale",
|
||
|
|
Channel: StreamDataChannel,
|
||
|
|
ReadTimeout: 20 * time.Millisecond,
|
||
|
|
}, staleEpoch, nil, nil, 0, nil, nil, nil, defaultStreamConfig())
|
||
|
|
if err := runtime.register(clientFileScope(), stream); err != nil {
|
||
|
|
t.Fatalf("register stale client stream failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
client.dispatchStreamEnvelope(newStreamDataEnvelope("client-stale", []byte("payload")))
|
||
|
|
|
||
|
|
readErr := readStreamError(t, stream, time.Second)
|
||
|
|
if !errors.Is(readErr, errTransportDetached) {
|
||
|
|
t.Fatalf("stale client stream read error mismatch: got %v want %v", readErr, errTransportDetached)
|
||
|
|
}
|
||
|
|
waitForStreamContextDone(t, stream.Context(), time.Second)
|
||
|
|
|
||
|
|
if _, ok := runtime.lookup(clientFileScope(), "client-stale"); ok {
|
||
|
|
t.Fatal("stale client stream should be removed from runtime")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestServerDispatchStreamEnvelopeRejectsTransportGenerationMismatch(t *testing.T) {
|
||
|
|
server := NewServer().(*ServerCommon)
|
||
|
|
runtime := server.getStreamRuntime()
|
||
|
|
clientConn := &ClientConn{
|
||
|
|
ClientID: "server-stale-peer",
|
||
|
|
server: server,
|
||
|
|
}
|
||
|
|
logical := logicalConnFromClient(clientConn)
|
||
|
|
scope := serverFileScope(logical)
|
||
|
|
|
||
|
|
stream := newStreamHandle(context.Background(), runtime, scope, StreamOpenRequest{
|
||
|
|
StreamID: "server-stale",
|
||
|
|
Channel: StreamDataChannel,
|
||
|
|
ReadTimeout: 20 * time.Millisecond,
|
||
|
|
}, 0, logical, &TransportConn{
|
||
|
|
logical: logical,
|
||
|
|
generation: 1,
|
||
|
|
remoteAddr: streamTestAddr("current"),
|
||
|
|
attached: true,
|
||
|
|
}, 1, nil, nil, nil, defaultStreamConfig())
|
||
|
|
if err := runtime.register(scope, stream); err != nil {
|
||
|
|
t.Fatalf("register server stream failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
server.dispatchStreamEnvelope(logical, &TransportConn{
|
||
|
|
logical: logical,
|
||
|
|
generation: 2,
|
||
|
|
remoteAddr: streamTestAddr("stale"),
|
||
|
|
attached: true,
|
||
|
|
}, nil, newStreamDataEnvelope("server-stale", []byte("stale-payload")))
|
||
|
|
|
||
|
|
readErr := readStreamError(t, stream, time.Second)
|
||
|
|
if !errors.Is(readErr, os.ErrDeadlineExceeded) {
|
||
|
|
t.Fatalf("server stale generation read error mismatch: got %v want %v", readErr, os.ErrDeadlineExceeded)
|
||
|
|
}
|
||
|
|
|
||
|
|
server.dispatchStreamEnvelope(logical, &TransportConn{
|
||
|
|
logical: logical,
|
||
|
|
generation: 1,
|
||
|
|
remoteAddr: streamTestAddr("current"),
|
||
|
|
attached: true,
|
||
|
|
}, nil, newStreamDataEnvelope("server-stale", []byte("good-payload")))
|
||
|
|
|
||
|
|
readStreamExactly(t, stream, "good-payload", time.Second)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestServerDispatchStreamEnvelopeRejectsTransportGenerationMismatchWritesResetViaInboundConn(t *testing.T) {
|
||
|
|
server := NewServer().(*ServerCommon)
|
||
|
|
UseLegacySecurityServer(server)
|
||
|
|
runtimeCtx, runtimeCancel := context.WithCancel(context.Background())
|
||
|
|
defer runtimeCancel()
|
||
|
|
server.setServerSessionRuntime(&serverSessionRuntime{
|
||
|
|
stopCtx: runtimeCtx,
|
||
|
|
stopFn: runtimeCancel,
|
||
|
|
queue: stario.NewQueueCtx(runtimeCtx, 4, math.MaxUint32),
|
||
|
|
})
|
||
|
|
|
||
|
|
left, right := net.Pipe()
|
||
|
|
defer left.Close()
|
||
|
|
defer right.Close()
|
||
|
|
|
||
|
|
logical := server.bootstrapAcceptedLogical("server-stream-reset", nil, left)
|
||
|
|
if logical == nil {
|
||
|
|
t.Fatal("bootstrapAcceptedLogical should return logical")
|
||
|
|
}
|
||
|
|
|
||
|
|
transport := logical.CurrentTransportConn()
|
||
|
|
if transport == nil {
|
||
|
|
t.Fatal("current transport should exist")
|
||
|
|
}
|
||
|
|
|
||
|
|
runtime := server.getStreamRuntime()
|
||
|
|
stream := newStreamHandle(context.Background(), runtime, serverFileScope(logical), StreamOpenRequest{
|
||
|
|
StreamID: "server-stale-reset",
|
||
|
|
Channel: StreamDataChannel,
|
||
|
|
}, 0, logical, transport, transport.TransportGeneration(), nil, nil, nil, defaultStreamConfig())
|
||
|
|
if err := runtime.register(serverFileScope(logical), stream); err != nil {
|
||
|
|
t.Fatalf("register server stream failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
staleTransport := logical.transportConnSnapshotForInbound(left, nil, transport.TransportGeneration()+1, true)
|
||
|
|
if staleTransport == nil {
|
||
|
|
t.Fatal("stale transport snapshot should exist")
|
||
|
|
}
|
||
|
|
|
||
|
|
done := make(chan struct{})
|
||
|
|
go func() {
|
||
|
|
server.dispatchStreamEnvelope(logical, staleTransport, left, newStreamDataEnvelope("server-stale-reset", []byte("stale-payload")))
|
||
|
|
close(done)
|
||
|
|
}()
|
||
|
|
|
||
|
|
env := readServerEnvelopeFromConn(t, server, logical, right, time.Second)
|
||
|
|
select {
|
||
|
|
case <-done:
|
||
|
|
case <-time.After(time.Second):
|
||
|
|
t.Fatal("timed out waiting for stream dispatch to finish")
|
||
|
|
}
|
||
|
|
if env.Kind != EnvelopeSignal {
|
||
|
|
t.Fatalf("reset envelope kind = %v, want %v", env.Kind, EnvelopeSignal)
|
||
|
|
}
|
||
|
|
transfer, err := unwrapTransferMsgEnvelope(env, server.sequenceDe)
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("unwrapTransferMsgEnvelope failed: %v", err)
|
||
|
|
}
|
||
|
|
if transfer.Key != StreamResetSignalKey {
|
||
|
|
t.Fatalf("reset transfer key = %q, want %q", transfer.Key, StreamResetSignalKey)
|
||
|
|
}
|
||
|
|
if transfer.Type != MSG_ASYNC {
|
||
|
|
t.Fatalf("reset transfer type = %v, want %v", transfer.Type, MSG_ASYNC)
|
||
|
|
}
|
||
|
|
var req StreamResetRequest
|
||
|
|
if err := transfer.Value.Orm(&req); err != nil {
|
||
|
|
t.Fatalf("decode reset request failed: %v", err)
|
||
|
|
}
|
||
|
|
if req.StreamID != "server-stale-reset" {
|
||
|
|
t.Fatalf("reset stream id = %q, want %q", req.StreamID, "server-stale-reset")
|
||
|
|
}
|
||
|
|
if !strings.HasPrefix(req.Error, errTransportDetached.Error()) {
|
||
|
|
t.Fatalf("reset error = %q, want prefix %q", req.Error, errTransportDetached.Error())
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestStreamBackpressureOverflowResetsStreamAndRemovesRuntimeEntry(t *testing.T) {
|
||
|
|
runtime := newStreamRuntime("overflow")
|
||
|
|
runtime.applyConfig(streamConfig{
|
||
|
|
ChunkSize: 4,
|
||
|
|
InboundQueueLimit: 1,
|
||
|
|
InboundBufferedBytesLimit: 4,
|
||
|
|
})
|
||
|
|
stream := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{
|
||
|
|
StreamID: "overflow-stream",
|
||
|
|
Channel: StreamDataChannel,
|
||
|
|
ReadTimeout: 20 * time.Millisecond,
|
||
|
|
}, 0, nil, nil, 0, nil, nil, nil, runtime.configSnapshot())
|
||
|
|
if err := runtime.register(clientFileScope(), stream); err != nil {
|
||
|
|
t.Fatalf("register stream failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if err := stream.pushChunk([]byte("abcd")); err != nil {
|
||
|
|
t.Fatalf("first pushChunk failed: %v", err)
|
||
|
|
}
|
||
|
|
if err := stream.pushChunk([]byte("ef")); !errors.Is(err, errStreamBackpressureExceeded) {
|
||
|
|
t.Fatalf("overflow pushChunk error = %v, want %v", err, errStreamBackpressureExceeded)
|
||
|
|
}
|
||
|
|
readErr := readStreamError(t, stream, time.Second)
|
||
|
|
if !errors.Is(readErr, errStreamBackpressureExceeded) {
|
||
|
|
t.Fatalf("stream read error = %v, want %v", readErr, errStreamBackpressureExceeded)
|
||
|
|
}
|
||
|
|
if _, ok := runtime.lookup(clientFileScope(), "overflow-stream"); ok {
|
||
|
|
t.Fatal("overflowed stream should be removed from runtime")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestServerDetachLogicalSessionTransportResetsScopedStreams(t *testing.T) {
|
||
|
|
server := NewServer().(*ServerCommon)
|
||
|
|
runtime := server.getStreamRuntime()
|
||
|
|
client := &ClientConn{
|
||
|
|
ClientID: "detached-peer",
|
||
|
|
server: server,
|
||
|
|
}
|
||
|
|
logical := logicalConnFromClient(client)
|
||
|
|
scope := serverFileScope(logical)
|
||
|
|
stream := newStreamHandle(context.Background(), runtime, scope, StreamOpenRequest{
|
||
|
|
StreamID: "detach-stream",
|
||
|
|
Channel: StreamDataChannel,
|
||
|
|
ReadTimeout: 20 * time.Millisecond,
|
||
|
|
}, 0, logical, &TransportConn{
|
||
|
|
logical: logical,
|
||
|
|
generation: 1,
|
||
|
|
remoteAddr: streamTestAddr("detach"),
|
||
|
|
attached: true,
|
||
|
|
}, 1, nil, nil, nil, defaultStreamConfig())
|
||
|
|
if err := runtime.register(scope, stream); err != nil {
|
||
|
|
t.Fatalf("register stream failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
left, right := net.Pipe()
|
||
|
|
defer left.Close()
|
||
|
|
defer right.Close()
|
||
|
|
logical.startSession(left, nil, nil)
|
||
|
|
server.detachLogicalSessionTransport(logical, "read error", errors.New("boom"))
|
||
|
|
|
||
|
|
readErr := readStreamError(t, stream, time.Second)
|
||
|
|
if !errors.Is(readErr, errTransportDetached) {
|
||
|
|
t.Fatalf("detached stream read error = %v, want %v", readErr, errTransportDetached)
|
||
|
|
}
|
||
|
|
if _, ok := runtime.lookup(scope, "detach-stream"); ok {
|
||
|
|
t.Fatal("detached stream should be removed from runtime")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestGetStreamSnapshotsIncludesBufferedState(t *testing.T) {
|
||
|
|
client := NewClient().(*ClientCommon)
|
||
|
|
runtime := client.getStreamRuntime()
|
||
|
|
stream := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{
|
||
|
|
StreamID: "snapshot-stream",
|
||
|
|
Channel: StreamControlChannel,
|
||
|
|
ReadTimeout: time.Second,
|
||
|
|
WriteTimeout: 2 * time.Second,
|
||
|
|
Metadata: StreamMetadata{
|
||
|
|
"name": "snapshot-demo",
|
||
|
|
},
|
||
|
|
}, 7, nil, nil, 0, nil, nil, nil, defaultStreamConfig())
|
||
|
|
stream.setClientSnapshotOwner(client)
|
||
|
|
if err := runtime.register(clientFileScope(), stream); err != nil {
|
||
|
|
t.Fatalf("register snapshot stream failed: %v", err)
|
||
|
|
}
|
||
|
|
if err := stream.pushChunk([]byte("hello")); err != nil {
|
||
|
|
t.Fatalf("pushChunk failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
snapshots, err := GetClientStreamSnapshots(client)
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("GetClientStreamSnapshots failed: %v", err)
|
||
|
|
}
|
||
|
|
if got, want := len(snapshots), 1; got != want {
|
||
|
|
t.Fatalf("stream snapshot count = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
snapshot := snapshots[0]
|
||
|
|
if got, want := snapshot.ID, "snapshot-stream"; got != want {
|
||
|
|
t.Fatalf("snapshot ID = %q, want %q", got, want)
|
||
|
|
}
|
||
|
|
if got, want := snapshot.Scope, clientFileScope(); got != want {
|
||
|
|
t.Fatalf("snapshot Scope = %q, want %q", got, want)
|
||
|
|
}
|
||
|
|
if got, want := snapshot.Channel, StreamControlChannel; got != want {
|
||
|
|
t.Fatalf("snapshot Channel = %q, want %q", got, want)
|
||
|
|
}
|
||
|
|
if got, want := snapshot.SessionEpoch, uint64(7); got != want {
|
||
|
|
t.Fatalf("snapshot SessionEpoch = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
if got, want := snapshot.BufferedChunks, 1; got != want {
|
||
|
|
t.Fatalf("snapshot BufferedChunks = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
if got, want := snapshot.BufferedBytes, 5; got != want {
|
||
|
|
t.Fatalf("snapshot BufferedBytes = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
if snapshot.LocalReadClosed {
|
||
|
|
t.Fatal("snapshot LocalReadClosed should be false")
|
||
|
|
}
|
||
|
|
if snapshot.PeerReadClosed {
|
||
|
|
t.Fatal("snapshot PeerReadClosed should be false")
|
||
|
|
}
|
||
|
|
if got := snapshot.Metadata["name"]; got != "snapshot-demo" {
|
||
|
|
t.Fatalf("snapshot metadata mismatch: %+v", snapshot.Metadata)
|
||
|
|
}
|
||
|
|
if got, want := snapshot.ReadTimeout, time.Second; got != want {
|
||
|
|
t.Fatalf("snapshot ReadTimeout = %v, want %v", got, want)
|
||
|
|
}
|
||
|
|
if got, want := snapshot.WriteTimeout, 2*time.Second; got != want {
|
||
|
|
t.Fatalf("snapshot WriteTimeout = %v, want %v", got, want)
|
||
|
|
}
|
||
|
|
if got, want := snapshot.BindingOwner, "client-session"; got != want {
|
||
|
|
t.Fatalf("snapshot BindingOwner = %q, want %q", got, want)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestGetStreamSnapshotsIncludesIOObservability(t *testing.T) {
|
||
|
|
runtime := newStreamRuntime("snapshot-observe")
|
||
|
|
stream := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{
|
||
|
|
StreamID: "snapshot-observe-stream",
|
||
|
|
Channel: StreamDataChannel,
|
||
|
|
}, 0, nil, nil, 0, nil, nil, func(context.Context, *streamHandle, []byte) error {
|
||
|
|
return nil
|
||
|
|
}, runtime.configSnapshot())
|
||
|
|
stream.setAddrSnapshot(streamTestAddr("local-addr"), streamTestAddr("remote-addr"))
|
||
|
|
if err := runtime.register(clientFileScope(), stream); err != nil {
|
||
|
|
t.Fatalf("register stream failed: %v", err)
|
||
|
|
}
|
||
|
|
if err := stream.pushChunk([]byte("hello")); err != nil {
|
||
|
|
t.Fatalf("pushChunk failed: %v", err)
|
||
|
|
}
|
||
|
|
buf := make([]byte, 2)
|
||
|
|
if _, err := stream.Read(buf); err != nil {
|
||
|
|
t.Fatalf("stream Read failed: %v", err)
|
||
|
|
}
|
||
|
|
if _, err := stream.Write([]byte("world")); err != nil {
|
||
|
|
t.Fatalf("stream Write failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
readDeadline := time.Now().Add(time.Minute).Round(0)
|
||
|
|
writeDeadline := time.Now().Add(2 * time.Minute).Round(0)
|
||
|
|
if err := stream.SetReadDeadline(readDeadline); err != nil {
|
||
|
|
t.Fatalf("SetReadDeadline failed: %v", err)
|
||
|
|
}
|
||
|
|
if err := stream.SetWriteDeadline(writeDeadline); err != nil {
|
||
|
|
t.Fatalf("SetWriteDeadline failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
snapshots := runtime.snapshots()
|
||
|
|
if got, want := len(snapshots), 1; got != want {
|
||
|
|
t.Fatalf("snapshot count = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
snapshot := snapshots[0]
|
||
|
|
if got, want := snapshot.LocalAddress, "local-addr"; got != want {
|
||
|
|
t.Fatalf("snapshot LocalAddress = %q, want %q", got, want)
|
||
|
|
}
|
||
|
|
if got, want := snapshot.RemoteAddress, "remote-addr"; got != want {
|
||
|
|
t.Fatalf("snapshot RemoteAddress = %q, want %q", got, want)
|
||
|
|
}
|
||
|
|
if got, want := snapshot.BytesRead, int64(2); got != want {
|
||
|
|
t.Fatalf("snapshot BytesRead = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
if got, want := snapshot.BytesWritten, int64(5); got != want {
|
||
|
|
t.Fatalf("snapshot BytesWritten = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
if got, want := snapshot.ReadCalls, int64(1); got != want {
|
||
|
|
t.Fatalf("snapshot ReadCalls = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
if got, want := snapshot.WriteCalls, int64(1); got != want {
|
||
|
|
t.Fatalf("snapshot WriteCalls = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
if snapshot.OpenedAt.IsZero() {
|
||
|
|
t.Fatal("snapshot OpenedAt should not be zero")
|
||
|
|
}
|
||
|
|
if snapshot.LastReadAt.IsZero() {
|
||
|
|
t.Fatal("snapshot LastReadAt should not be zero")
|
||
|
|
}
|
||
|
|
if snapshot.LastWriteAt.IsZero() {
|
||
|
|
t.Fatal("snapshot LastWriteAt should not be zero")
|
||
|
|
}
|
||
|
|
if got, want := snapshot.ReadDeadline, readDeadline; !got.Equal(want) {
|
||
|
|
t.Fatalf("snapshot ReadDeadline = %v, want %v", got, want)
|
||
|
|
}
|
||
|
|
if got, want := snapshot.WriteDeadline, writeDeadline; !got.Equal(want) {
|
||
|
|
t.Fatalf("snapshot WriteDeadline = %v, want %v", got, want)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestStreamSnapshotIncludesDetachedBindingDiagnostics(t *testing.T) {
|
||
|
|
server := NewServer().(*ServerCommon)
|
||
|
|
left, right := net.Pipe()
|
||
|
|
defer right.Close()
|
||
|
|
|
||
|
|
logical := server.bootstrapAcceptedLogical("stream-snapshot-detach", nil, left)
|
||
|
|
if logical == nil {
|
||
|
|
t.Fatal("bootstrapAcceptedLogical should return logical")
|
||
|
|
}
|
||
|
|
transport := logical.CurrentTransportConn()
|
||
|
|
if transport == nil {
|
||
|
|
t.Fatal("CurrentTransportConn should return active transport")
|
||
|
|
}
|
||
|
|
stream := newStreamHandle(context.Background(), newStreamRuntime("snapshot-detach"), serverFileScope(logical), StreamOpenRequest{
|
||
|
|
StreamID: "stream-snapshot-detach",
|
||
|
|
}, 0, logical, transport, transport.TransportGeneration(), nil, nil, nil, defaultStreamConfig())
|
||
|
|
|
||
|
|
server.detachLogicalSessionTransport(logical, "read error", errors.New("boom"))
|
||
|
|
|
||
|
|
snapshot := stream.snapshot()
|
||
|
|
if got, want := snapshot.BindingOwner, "server-transport"; got != want {
|
||
|
|
t.Fatalf("snapshot BindingOwner = %q, want %q", got, want)
|
||
|
|
}
|
||
|
|
if snapshot.BindingCurrent {
|
||
|
|
t.Fatalf("snapshot BindingCurrent should be false after detach: %+v", snapshot)
|
||
|
|
}
|
||
|
|
if snapshot.TransportAttached {
|
||
|
|
t.Fatalf("snapshot TransportAttached should be false after detach: %+v", snapshot)
|
||
|
|
}
|
||
|
|
if snapshot.TransportCurrent {
|
||
|
|
t.Fatalf("snapshot TransportCurrent should be false after detach: %+v", snapshot)
|
||
|
|
}
|
||
|
|
if got, want := snapshot.TransportDetachReason, "read error"; got != want {
|
||
|
|
t.Fatalf("snapshot TransportDetachReason = %q, want %q", got, want)
|
||
|
|
}
|
||
|
|
if got, want := snapshot.TransportDetachKind, clientConnTransportDetachKindReadError; got != want {
|
||
|
|
t.Fatalf("snapshot TransportDetachKind = %q, want %q", got, want)
|
||
|
|
}
|
||
|
|
if got, want := snapshot.TransportDetachError, "boom"; got != want {
|
||
|
|
t.Fatalf("snapshot TransportDetachError = %q, want %q", got, want)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func waitForStreamReadEOF(t *testing.T, stream Stream, timeout time.Duration) {
|
||
|
|
t.Helper()
|
||
|
|
|
||
|
|
deadline := time.Now().Add(timeout)
|
||
|
|
buf := make([]byte, 1)
|
||
|
|
for time.Now().Before(deadline) {
|
||
|
|
_, err := stream.Read(buf)
|
||
|
|
if errors.Is(err, io.EOF) {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
if err != nil && !errors.Is(err, errStreamDataPathNotReady) {
|
||
|
|
t.Fatalf("stream Read returned unexpected error: %v", err)
|
||
|
|
}
|
||
|
|
time.Sleep(10 * time.Millisecond)
|
||
|
|
}
|
||
|
|
t.Fatal("timed out waiting for stream EOF")
|
||
|
|
}
|
||
|
|
|
||
|
|
// waitForStreamContextDone blocks until ctx is done, failing the test if that
// does not happen within timeout.
func waitForStreamContextDone(t *testing.T, ctx context.Context, timeout time.Duration) {
	t.Helper()

	timer := time.NewTimer(timeout)
	defer timer.Stop()
	select {
	case <-ctx.Done():
		return
	case <-timer.C:
	}
	t.Fatal("timed out waiting for stream context done")
}
|
||
|
|
|
||
|
|
func waitAcceptedStream(t *testing.T, ch <-chan StreamAcceptInfo, timeout time.Duration) StreamAcceptInfo {
|
||
|
|
t.Helper()
|
||
|
|
|
||
|
|
select {
|
||
|
|
case info := <-ch:
|
||
|
|
return info
|
||
|
|
case <-time.After(timeout):
|
||
|
|
t.Fatal("timed out waiting for accepted stream")
|
||
|
|
return StreamAcceptInfo{}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func readStreamExactly(t *testing.T, stream Stream, want string, timeout time.Duration) {
|
||
|
|
t.Helper()
|
||
|
|
|
||
|
|
errCh := make(chan error, 1)
|
||
|
|
go func() {
|
||
|
|
buf := make([]byte, len(want))
|
||
|
|
_, err := io.ReadFull(stream, buf)
|
||
|
|
if err != nil {
|
||
|
|
errCh <- err
|
||
|
|
return
|
||
|
|
}
|
||
|
|
if got := string(buf); got != want {
|
||
|
|
errCh <- errors.New("stream payload mismatch: got " + got + " want " + want)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
errCh <- nil
|
||
|
|
}()
|
||
|
|
|
||
|
|
select {
|
||
|
|
case err := <-errCh:
|
||
|
|
if err != nil {
|
||
|
|
t.Fatal(err)
|
||
|
|
}
|
||
|
|
case <-time.After(timeout):
|
||
|
|
t.Fatal("timed out waiting for stream payload")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func readStreamError(t *testing.T, stream Stream, timeout time.Duration) error {
|
||
|
|
t.Helper()
|
||
|
|
|
||
|
|
errCh := make(chan error, 1)
|
||
|
|
go func() {
|
||
|
|
buf := make([]byte, 1)
|
||
|
|
_, err := stream.Read(buf)
|
||
|
|
errCh <- err
|
||
|
|
}()
|
||
|
|
select {
|
||
|
|
case err := <-errCh:
|
||
|
|
if err == nil {
|
||
|
|
t.Fatal("expected stream read error, got nil")
|
||
|
|
}
|
||
|
|
return err
|
||
|
|
case <-time.After(timeout):
|
||
|
|
t.Fatal("timed out waiting for stream read error")
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func readServerEnvelopeFromConn(t *testing.T, server *ServerCommon, logical *LogicalConn, conn net.Conn, timeout time.Duration) Envelope {
|
||
|
|
t.Helper()
|
||
|
|
|
||
|
|
queue := stario.NewQueue()
|
||
|
|
deadline := time.Now().Add(timeout)
|
||
|
|
buf := make([]byte, 4096)
|
||
|
|
for time.Now().Before(deadline) {
|
||
|
|
if err := conn.SetReadDeadline(deadline); err != nil {
|
||
|
|
t.Fatalf("SetReadDeadline failed: %v", err)
|
||
|
|
}
|
||
|
|
n, err := conn.Read(buf)
|
||
|
|
if n > 0 {
|
||
|
|
if parseErr := queue.ParseMessage(buf[:n], "stream-test"); parseErr != nil {
|
||
|
|
t.Fatalf("ParseMessage failed: %v", parseErr)
|
||
|
|
}
|
||
|
|
select {
|
||
|
|
case msg := <-queue.RestoreChan():
|
||
|
|
env, decErr := server.decodeEnvelopeLogical(logical, msg.Msg)
|
||
|
|
if decErr != nil {
|
||
|
|
t.Fatalf("decodeEnvelopeLogical failed: %v", decErr)
|
||
|
|
}
|
||
|
|
return env
|
||
|
|
default:
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if err == nil {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||
|
|
break
|
||
|
|
}
|
||
|
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||
|
|
break
|
||
|
|
}
|
||
|
|
t.Fatalf("conn Read failed: %v", err)
|
||
|
|
}
|
||
|
|
t.Fatal("timed out waiting for server envelope")
|
||
|
|
return Envelope{}
|
||
|
|
}
|
||
|
|
|
||
|
|
// streamTestAddr is a fake net.Addr for stream tests: String reports the
// wrapped text verbatim and Network always reports "stream-test".
type streamTestAddr string

// Network returns the fixed pseudo-network name "stream-test".
func (addr streamTestAddr) Network() string { return "stream-test" }

// String returns the address text the value was created from.
func (addr streamTestAddr) String() string { return string(addr) }

// Compile-time assertion that streamTestAddr implements net.Addr.
var _ net.Addr = streamTestAddr("")
|