67 lines
2.1 KiB
Go
67 lines
2.1 KiB
Go
|
|
package notify
|
||
|
|
|
||
|
|
import (
|
||
|
|
"bytes"
|
||
|
|
"encoding/binary"
|
||
|
|
"io"
|
||
|
|
"net"
|
||
|
|
"testing"
|
||
|
|
"time"
|
||
|
|
)
|
||
|
|
|
||
|
|
// shortWriteBulkRecordConn is an in-memory connection stub that implements
// the net.Conn method set. Its Write accepts at most maxPerWrite bytes per
// call and appends what it accepted to buf, which lets a test observe how a
// writer copes with short writes.
type shortWriteBulkRecordConn struct {
	// maxPerWrite caps how many bytes a single Write call accepts.
	// Zero or a negative value disables the cap.
	maxPerWrite int
	// buf accumulates, in order, every byte accepted by Write.
	buf bytes.Buffer
}
|
||
|
|
|
||
|
|
// Read never yields data: it reports io.EOF immediately without touching
// the supplied buffer.
func (c *shortWriteBulkRecordConn) Read(_ []byte) (int, error) {
	return 0, io.EOF
}
|
||
|
|
|
||
|
|
func (c *shortWriteBulkRecordConn) Write(p []byte) (int, error) {
|
||
|
|
if len(p) == 0 {
|
||
|
|
return 0, nil
|
||
|
|
}
|
||
|
|
n := c.maxPerWrite
|
||
|
|
if n <= 0 || n > len(p) {
|
||
|
|
n = len(p)
|
||
|
|
}
|
||
|
|
_, _ = c.buf.Write(p[:n])
|
||
|
|
return n, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// Close does nothing and always reports success.
func (c *shortWriteBulkRecordConn) Close() error {
	return nil
}
|
||
|
|
// LocalAddr always returns the fixed stub address "local".
func (c *shortWriteBulkRecordConn) LocalAddr() net.Addr {
	return shortWriteBulkRecordAddr("local")
}
|
||
|
|
// RemoteAddr always returns the fixed stub address "remote".
func (c *shortWriteBulkRecordConn) RemoteAddr() net.Addr {
	return shortWriteBulkRecordAddr("remote")
}
|
||
|
|
// SetDeadline discards the given deadline and always returns nil.
func (c *shortWriteBulkRecordConn) SetDeadline(_ time.Time) error {
	return nil
}
|
||
|
|
func (c *shortWriteBulkRecordConn) SetReadDeadline(time.Time) error {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
func (c *shortWriteBulkRecordConn) SetWriteDeadline(time.Time) error {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// shortWriteBulkRecordAddr is a string-backed stub address. Its Network and
// String methods let it serve as a net.Addr for shortWriteBulkRecordConn's
// LocalAddr and RemoteAddr.
type shortWriteBulkRecordAddr string
|
||
|
|
|
||
|
|
// Network reports "tcp" for every stub address, whatever its value.
func (a shortWriteBulkRecordAddr) Network() string {
	return "tcp"
}
|
||
|
|
// String returns the address text exactly as it was constructed.
func (a shortWriteBulkRecordAddr) String() string {
	text := string(a)
	return text
}
|
||
|
|
|
||
|
|
func TestWriteBulkDedicatedRecordWithDeadlineHandlesShortWrite(t *testing.T) {
|
||
|
|
conn := &shortWriteBulkRecordConn{maxPerWrite: 3}
|
||
|
|
payload := []byte("abcdefghijklmnopqrstuvwxyz")
|
||
|
|
if err := writeBulkDedicatedRecordWithDeadline(conn, payload, time.Time{}); err != nil {
|
||
|
|
t.Fatalf("writeBulkDedicatedRecordWithDeadline failed: %v", err)
|
||
|
|
}
|
||
|
|
raw := conn.buf.Bytes()
|
||
|
|
if got, want := len(raw), bulkDedicatedRecordHeaderLen+len(payload); got != want {
|
||
|
|
t.Fatalf("record length = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
if got := string(raw[:4]); got != bulkDedicatedRecordMagic {
|
||
|
|
t.Fatalf("record magic = %q, want %q", got, bulkDedicatedRecordMagic)
|
||
|
|
}
|
||
|
|
if got, want := int(binary.BigEndian.Uint32(raw[4:8])), len(payload); got != want {
|
||
|
|
t.Fatalf("record payload length = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
if got := raw[bulkDedicatedRecordHeaderLen:]; !bytes.Equal(got, payload) {
|
||
|
|
t.Fatalf("record payload mismatch")
|
||
|
|
}
|
||
|
|
}
|