package notify

import (
	"bytes"
	"context"
	"testing"
	"time"
)

// TestOpenClientStreamFromReaderTCP opens a data-channel stream from the
// client by copying an in-memory reader into it. It then checks that the
// server's stream handler sees the same stream ID and metadata, and that the
// server receives the full payload.
func TestOpenClientStreamFromReaderTCP(t *testing.T) {
	server := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}

	acceptCh := make(chan StreamAcceptInfo, 1)
	payloadCh := make(chan []byte, 1)
	errCh := make(chan error, 1)
	server.SetStreamHandler(newReaderStreamCaptureHandler(acceptCh, payloadCh, errCh))

	if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	// Teardown errors are not what this test verifies.
	defer func() { _ = server.Stop() }()

	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	// Teardown errors are not what this test verifies.
	defer func() { _ = client.Stop() }()

	// Bound the open-and-copy call so a stalled transport fails this test
	// instead of hanging until the global `go test` timeout.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	payload := []byte("client-reader-stream-payload")
	stream, written, err := OpenClientStreamFromReader(ctx, client, bytes.NewReader(payload), StreamOpenCopyOptions{
		Open: StreamOpenOptions{
			Channel: StreamDataChannel,
			Metadata: StreamMetadata{
				"name": "reader.bin",
			},
		},
	})
	if err != nil {
		t.Fatalf("OpenClientStreamFromReader failed: %v", err)
	}
	if got, want := written, int64(len(payload)); got != want {
		t.Fatalf("written = %d, want %d", got, want)
	}

	info := waitAcceptedStream(t, acceptCh, 2*time.Second)
	if info.ID != stream.ID() {
		t.Fatalf("accepted stream id = %q, want %q", info.ID, stream.ID())
	}
	if got, want := info.Metadata["name"], "reader.bin"; got != want {
		t.Fatalf("accepted metadata[name] = %q, want %q", got, want)
	}

	awaitReaderStreamPayload(t, payloadCh, errCh, payload, 2*time.Second)
}

// TestOpenServerLogicalStreamFromReaderTCP opens a control-channel stream from
// the server on the client's logical connection by copying an in-memory
// reader into it. It then checks that the client's stream handler sees the
// same stream ID, channel and metadata, and that the client receives the full
// payload.
func TestOpenServerLogicalStreamFromReaderTCP(t *testing.T) {
	server := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	// Teardown errors are not what this test verifies.
	defer func() { _ = server.Stop() }()

	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}

	acceptCh := make(chan StreamAcceptInfo, 1)
	payloadCh := make(chan []byte, 1)
	errCh := make(chan error, 1)
	// The handler is installed before Connect so no server-initiated stream
	// can arrive without it.
	client.SetStreamHandler(newReaderStreamCaptureHandler(acceptCh, payloadCh, errCh))

	if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	// Teardown errors are not what this test verifies.
	defer func() { _ = client.Stop() }()

	logical := waitForTransferControlLogicalConn(t, server, 2*time.Second)

	// Bound the open-and-copy call so a stalled transport fails this test
	// instead of hanging until the global `go test` timeout.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	payload := []byte("server-logical-reader-stream")
	stream, written, err := OpenServerLogicalStreamFromReader(ctx, server, logical, bytes.NewReader(payload), StreamOpenCopyOptions{
		Open: StreamOpenOptions{
			Channel: StreamControlChannel,
			Metadata: StreamMetadata{
				"role": "server",
			},
		},
	})
	if err != nil {
		t.Fatalf("OpenServerLogicalStreamFromReader failed: %v", err)
	}
	if got, want := written, int64(len(payload)); got != want {
		t.Fatalf("written = %d, want %d", got, want)
	}

	info := waitAcceptedStream(t, acceptCh, 2*time.Second)
	if info.ID != stream.ID() {
		t.Fatalf("accepted stream id = %q, want %q", info.ID, stream.ID())
	}
	if got, want := info.Channel, StreamControlChannel; got != want {
		t.Fatalf("accepted stream channel = %q, want %q", got, want)
	}
	if got, want := info.Metadata["role"], "server"; got != want {
		t.Fatalf("accepted metadata[role] = %q, want %q", got, want)
	}

	awaitReaderStreamPayload(t, payloadCh, errCh, payload, 2*time.Second)
}

// newReaderStreamCaptureHandler returns a stream handler that does two things
// for each accepted stream:
//   - It sends the accept info on acceptCh.
//   - It starts a goroutine that drains the stream into a buffer with
//     CopyStreamToWriter. That goroutine sends the copied bytes on payloadCh,
//     or the copy error on errCh.
//
// The handler itself always returns nil.
func newReaderStreamCaptureHandler(acceptCh chan<- StreamAcceptInfo, payloadCh chan<- []byte, errCh chan<- error) func(StreamAcceptInfo) error {
	return func(info StreamAcceptInfo) error {
		acceptCh <- info
		go func() {
			var dst bytes.Buffer
			if _, err := CopyStreamToWriter(context.Background(), info.Stream, &dst, StreamCopyOptions{}); err != nil {
				errCh <- err
				return
			}
			payloadCh <- dst.Bytes()
		}()
		return nil
	}
}

// awaitReaderStreamPayload waits up to timeout for a handler from
// newReaderStreamCaptureHandler to report its result. It fails the test if
// the copy reported an error, if the copied bytes differ from want, or if
// nothing arrives in time.
func awaitReaderStreamPayload(t *testing.T, payloadCh <-chan []byte, errCh <-chan error, want []byte, timeout time.Duration) {
	t.Helper()
	select {
	case err := <-errCh:
		t.Fatalf("CopyStreamToWriter failed: %v", err)
	case got := <-payloadCh:
		if !bytes.Equal(got, want) {
			t.Fatalf("payload mismatch: got %q want %q", string(got), string(want))
		}
	case <-time.After(timeout):
		t.Fatal("timed out waiting for copied payload")
	}
}