notify/stream_reader_writer_test.go

157 lines
4.8 KiB
Go
Raw Normal View History

package notify
import (
"bytes"
"context"
"testing"
"time"
)
// TestOpenClientStreamFromReaderTCP checks OpenClientStreamFromReader against
// a modern-PSK TCP server. It verifies that:
//   - the reported byte count equals the payload length;
//   - the server's stream handler accepts a stream with the same ID, channel,
//     and metadata;
//   - the handler receives exactly the payload bytes.
func TestOpenClientStreamFromReaderTCP(t *testing.T) {
	// Bound the whole test, including the server-side copy goroutine, so a
	// stalled stream fails the test instead of hanging it. Canceling on return
	// also gives the copy goroutine a stop signal once the test is done.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	server := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}

	acceptCh := make(chan StreamAcceptInfo, 1)
	payloadCh := make(chan []byte, 1)
	errCh := make(chan error, 1)
	server.SetStreamHandler(func(info StreamAcceptInfo) error {
		acceptCh <- info
		// Drain the stream asynchronously so the handler returns promptly.
		// The outcome is reported back through the buffered channels.
		go func() {
			var dst bytes.Buffer
			if _, err := CopyStreamToWriter(ctx, info.Stream, &dst, StreamCopyOptions{}); err != nil {
				errCh <- err
				return
			}
			payloadCh <- dst.Bytes()
		}()
		return nil
	})
	if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	defer func() { _ = server.Stop() }()

	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	defer func() { _ = client.Stop() }()

	payload := []byte("client-reader-stream-payload")
	stream, written, err := OpenClientStreamFromReader(ctx, client, bytes.NewReader(payload), StreamOpenCopyOptions{
		Open: StreamOpenOptions{
			Channel: StreamDataChannel,
			Metadata: StreamMetadata{
				"name": "reader.bin",
			},
		},
	})
	if err != nil {
		t.Fatalf("OpenClientStreamFromReader failed: %v", err)
	}
	if got, want := written, int64(len(payload)); got != want {
		t.Fatalf("written = %d, want %d", got, want)
	}

	info := waitAcceptedStream(t, acceptCh, 2*time.Second)
	if info.ID != stream.ID() {
		t.Fatalf("accepted stream id = %q, want %q", info.ID, stream.ID())
	}
	if got, want := info.Channel, StreamDataChannel; got != want {
		t.Fatalf("accepted stream channel = %q, want %q", got, want)
	}
	if got, want := info.Metadata["name"], "reader.bin"; got != want {
		t.Fatalf("accepted metadata[name] = %q, want %q", got, want)
	}

	select {
	case err := <-errCh:
		t.Fatalf("CopyStreamToWriter failed: %v", err)
	case got := <-payloadCh:
		if !bytes.Equal(got, payload) {
			t.Fatalf("payload mismatch: got %q want %q", string(got), string(payload))
		}
	case <-time.After(2 * time.Second):
		t.Fatal("timed out waiting for copied payload")
	}
}
// TestOpenServerLogicalStreamFromReaderTCP checks the server-to-client
// direction. The server opens a control-channel stream on a client's logical
// connection via OpenServerLogicalStreamFromReader. The test verifies that:
//   - the reported byte count equals the payload length;
//   - the client's stream handler accepts a stream with the same ID, channel,
//     and metadata;
//   - the handler receives exactly the payload bytes.
func TestOpenServerLogicalStreamFromReaderTCP(t *testing.T) {
	// Bound the whole test, including the client-side copy goroutine, so a
	// stalled stream fails the test instead of hanging it. Canceling on return
	// also gives the copy goroutine a stop signal once the test is done.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	server := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	defer func() { _ = server.Stop() }()

	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}

	acceptCh := make(chan StreamAcceptInfo, 1)
	payloadCh := make(chan []byte, 1)
	errCh := make(chan error, 1)
	// The handler is installed before Connect, so it is in place before the
	// server can open a stream toward this client.
	client.SetStreamHandler(func(info StreamAcceptInfo) error {
		acceptCh <- info
		// Drain the stream asynchronously so the handler returns promptly.
		// The outcome is reported back through the buffered channels.
		go func() {
			var dst bytes.Buffer
			if _, err := CopyStreamToWriter(ctx, info.Stream, &dst, StreamCopyOptions{}); err != nil {
				errCh <- err
				return
			}
			payloadCh <- dst.Bytes()
		}()
		return nil
	})
	if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	defer func() { _ = client.Stop() }()

	logical := waitForTransferControlLogicalConn(t, server, 2*time.Second)
	payload := []byte("server-logical-reader-stream")
	stream, written, err := OpenServerLogicalStreamFromReader(ctx, server, logical, bytes.NewReader(payload), StreamOpenCopyOptions{
		Open: StreamOpenOptions{
			Channel: StreamControlChannel,
			Metadata: StreamMetadata{
				"role": "server",
			},
		},
	})
	if err != nil {
		t.Fatalf("OpenServerLogicalStreamFromReader failed: %v", err)
	}
	if got, want := written, int64(len(payload)); got != want {
		t.Fatalf("written = %d, want %d", got, want)
	}

	info := waitAcceptedStream(t, acceptCh, 2*time.Second)
	if info.ID != stream.ID() {
		t.Fatalf("accepted stream id = %q, want %q", info.ID, stream.ID())
	}
	if got, want := info.Channel, StreamControlChannel; got != want {
		t.Fatalf("accepted stream channel = %q, want %q", got, want)
	}
	if got, want := info.Metadata["role"], "server"; got != want {
		t.Fatalf("accepted metadata[role] = %q, want %q", got, want)
	}

	select {
	case err := <-errCh:
		t.Fatalf("CopyStreamToWriter failed: %v", err)
	case got := <-payloadCh:
		if !bytes.Equal(got, payload) {
			t.Fatalf("payload mismatch: got %q want %q", string(got), string(payload))
		}
	case <-time.After(2 * time.Second):
		t.Fatal("timed out waiting for copied payload")
	}
}