157 lines
4.8 KiB
Go
157 lines
4.8 KiB
Go
|
|
package notify
|
||
|
|
|
||
|
|
import (
|
||
|
|
"bytes"
|
||
|
|
"context"
|
||
|
|
"testing"
|
||
|
|
"time"
|
||
|
|
)
|
||
|
|
|
||
|
|
func TestOpenClientStreamFromReaderTCP(t *testing.T) {
|
||
|
|
server := NewServer().(*ServerCommon)
|
||
|
|
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||
|
|
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
acceptCh := make(chan StreamAcceptInfo, 1)
|
||
|
|
payloadCh := make(chan []byte, 1)
|
||
|
|
errCh := make(chan error, 1)
|
||
|
|
server.SetStreamHandler(func(info StreamAcceptInfo) error {
|
||
|
|
acceptCh <- info
|
||
|
|
go func() {
|
||
|
|
var dst bytes.Buffer
|
||
|
|
_, err := CopyStreamToWriter(context.Background(), info.Stream, &dst, StreamCopyOptions{})
|
||
|
|
if err != nil {
|
||
|
|
errCh <- err
|
||
|
|
return
|
||
|
|
}
|
||
|
|
payloadCh <- dst.Bytes()
|
||
|
|
}()
|
||
|
|
return nil
|
||
|
|
})
|
||
|
|
|
||
|
|
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
|
||
|
|
t.Fatalf("server Listen failed: %v", err)
|
||
|
|
}
|
||
|
|
defer func() { _ = server.Stop() }()
|
||
|
|
|
||
|
|
client := NewClient().(*ClientCommon)
|
||
|
|
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||
|
|
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||
|
|
}
|
||
|
|
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
|
||
|
|
t.Fatalf("client Connect failed: %v", err)
|
||
|
|
}
|
||
|
|
defer func() { _ = client.Stop() }()
|
||
|
|
|
||
|
|
payload := []byte("client-reader-stream-payload")
|
||
|
|
stream, written, err := OpenClientStreamFromReader(context.Background(), client, bytes.NewReader(payload), StreamOpenCopyOptions{
|
||
|
|
Open: StreamOpenOptions{
|
||
|
|
Channel: StreamDataChannel,
|
||
|
|
Metadata: StreamMetadata{
|
||
|
|
"name": "reader.bin",
|
||
|
|
},
|
||
|
|
},
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("OpenClientStreamFromReader failed: %v", err)
|
||
|
|
}
|
||
|
|
if got, want := written, int64(len(payload)); got != want {
|
||
|
|
t.Fatalf("written = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
|
||
|
|
info := waitAcceptedStream(t, acceptCh, 2*time.Second)
|
||
|
|
if info.ID != stream.ID() {
|
||
|
|
t.Fatalf("accepted stream id = %q, want %q", info.ID, stream.ID())
|
||
|
|
}
|
||
|
|
if got, want := info.Metadata["name"], "reader.bin"; got != want {
|
||
|
|
t.Fatalf("accepted metadata[name] = %q, want %q", got, want)
|
||
|
|
}
|
||
|
|
|
||
|
|
select {
|
||
|
|
case err := <-errCh:
|
||
|
|
t.Fatalf("CopyStreamToWriter failed: %v", err)
|
||
|
|
case got := <-payloadCh:
|
||
|
|
if !bytes.Equal(got, payload) {
|
||
|
|
t.Fatalf("payload mismatch: got %q want %q", string(got), string(payload))
|
||
|
|
}
|
||
|
|
case <-time.After(2 * time.Second):
|
||
|
|
t.Fatal("timed out waiting for copied payload")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestOpenServerLogicalStreamFromReaderTCP(t *testing.T) {
|
||
|
|
server := NewServer().(*ServerCommon)
|
||
|
|
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||
|
|
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||
|
|
}
|
||
|
|
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
|
||
|
|
t.Fatalf("server Listen failed: %v", err)
|
||
|
|
}
|
||
|
|
defer func() { _ = server.Stop() }()
|
||
|
|
|
||
|
|
client := NewClient().(*ClientCommon)
|
||
|
|
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||
|
|
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||
|
|
}
|
||
|
|
acceptCh := make(chan StreamAcceptInfo, 1)
|
||
|
|
payloadCh := make(chan []byte, 1)
|
||
|
|
errCh := make(chan error, 1)
|
||
|
|
client.SetStreamHandler(func(info StreamAcceptInfo) error {
|
||
|
|
acceptCh <- info
|
||
|
|
go func() {
|
||
|
|
var dst bytes.Buffer
|
||
|
|
_, err := CopyStreamToWriter(context.Background(), info.Stream, &dst, StreamCopyOptions{})
|
||
|
|
if err != nil {
|
||
|
|
errCh <- err
|
||
|
|
return
|
||
|
|
}
|
||
|
|
payloadCh <- dst.Bytes()
|
||
|
|
}()
|
||
|
|
return nil
|
||
|
|
})
|
||
|
|
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
|
||
|
|
t.Fatalf("client Connect failed: %v", err)
|
||
|
|
}
|
||
|
|
defer func() { _ = client.Stop() }()
|
||
|
|
|
||
|
|
logical := waitForTransferControlLogicalConn(t, server, 2*time.Second)
|
||
|
|
payload := []byte("server-logical-reader-stream")
|
||
|
|
stream, written, err := OpenServerLogicalStreamFromReader(context.Background(), server, logical, bytes.NewReader(payload), StreamOpenCopyOptions{
|
||
|
|
Open: StreamOpenOptions{
|
||
|
|
Channel: StreamControlChannel,
|
||
|
|
Metadata: StreamMetadata{
|
||
|
|
"role": "server",
|
||
|
|
},
|
||
|
|
},
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("OpenServerLogicalStreamFromReader failed: %v", err)
|
||
|
|
}
|
||
|
|
if got, want := written, int64(len(payload)); got != want {
|
||
|
|
t.Fatalf("written = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
|
||
|
|
info := waitAcceptedStream(t, acceptCh, 2*time.Second)
|
||
|
|
if info.ID != stream.ID() {
|
||
|
|
t.Fatalf("accepted stream id = %q, want %q", info.ID, stream.ID())
|
||
|
|
}
|
||
|
|
if got, want := info.Channel, StreamControlChannel; got != want {
|
||
|
|
t.Fatalf("accepted stream channel = %q, want %q", got, want)
|
||
|
|
}
|
||
|
|
if got, want := info.Metadata["role"], "server"; got != want {
|
||
|
|
t.Fatalf("accepted metadata[role] = %q, want %q", got, want)
|
||
|
|
}
|
||
|
|
|
||
|
|
select {
|
||
|
|
case err := <-errCh:
|
||
|
|
t.Fatalf("CopyStreamToWriter failed: %v", err)
|
||
|
|
case got := <-payloadCh:
|
||
|
|
if !bytes.Equal(got, payload) {
|
||
|
|
t.Fatalf("payload mismatch: got %q want %q", string(got), string(payload))
|
||
|
|
}
|
||
|
|
case <-time.After(2 * time.Second):
|
||
|
|
t.Fatal("timed out waiting for copied payload")
|
||
|
|
}
|
||
|
|
}
|