// File: whoissdk/client_dial_test.go (79 lines, 1.5 KiB, Go).
// Snapshot timestamp: 2026-03-19 11:53:07 +08:00.

package whois
import (
"context"
"errors"
"net"
"sync"
"testing"
"time"
)
// TestDialContextCancelClosesLateConnection checks that when ctx expires
// while the configured Dialer is still blocked, dialContext returns a context
// error (DeadlineExceeded or Canceled), and that the connection the dialer
// hands back afterwards is closed instead of leaked.
//
// Every wait is bounded so a regression fails fast instead of hanging until
// the global `go test` timeout, and the dialer is always released on exit so
// its goroutine does not outlive the test.
func TestDialContextCancelClosesLateConnection(t *testing.T) {
	left, right := net.Pipe()
	defer right.Close()
	spyConn := &closeSpyConn{
		Conn:   left,
		closed: make(chan struct{}),
	}
	dialer := &blockingDialer{
		ready:   make(chan struct{}),
		release: make(chan struct{}),
		conn:    spyConn,
	}

	// release unblocks the dialer exactly once. Registering it as a cleanup
	// guarantees that a t.Fatal before the explicit call below still frees
	// the goroutine parked inside blockingDialer.Dial.
	var releaseOnce sync.Once
	release := func() { releaseOnce.Do(func() { close(dialer.release) }) }
	t.Cleanup(release)

	c := NewClient().SetDialer(dialer)
	ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond)
	defer cancel()

	errCh := make(chan error, 1)
	go func() {
		_, err := c.dialContext(ctx, "tcp", "example.com:43")
		errCh <- err
	}()

	// Upper bound for each wait; only reached when the code under test is broken.
	const waitLimit = time.Second

	select {
	case <-dialer.ready:
	case <-time.After(waitLimit):
		t.Fatal("dialer was never invoked")
	}

	var err error
	select {
	case err = <-errCh:
	case <-time.After(waitLimit):
		t.Fatal("dialContext did not return after the context deadline")
	}
	if err == nil {
		t.Fatal("expected context cancellation error")
	}
	if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
		t.Fatalf("unexpected error: %v", err)
	}

	// Let the dialer hand back its connection now that dialContext has given up.
	release()
	select {
	case <-spyConn.closed:
	case <-time.After(waitLimit):
		t.Fatal("expected late connection to be closed after context cancel")
	}
}
// blockingDialer is a test double whose Dial call parks until the test lets
// it finish, so the test controls exactly when a "late" result is handed back.
type blockingDialer struct {
	// ready is closed on the first Dial call so the test knows Dial has started.
	ready chan struct{}
	// release is closed by the test to let every pending Dial return.
	release chan struct{}
	// conn and err are returned verbatim by Dial.
	conn net.Conn
	err  error
	// once guards the close of ready against repeated or concurrent Dial calls.
	once sync.Once
}

// Dial announces itself on ready (first call only), waits until release is
// closed, and then returns the preconfigured conn and err. The network and
// address arguments are ignored.
func (d *blockingDialer) Dial(_, _ string) (net.Conn, error) {
	announce := func() { close(d.ready) }
	d.once.Do(announce)
	<-d.release
	return d.conn, d.err
}
// closeSpyConn decorates a net.Conn so a test can observe that Close was
// invoked: the closed channel is closed on the first Close call, and every
// call is still forwarded to the wrapped connection.
type closeSpyConn struct {
	net.Conn
	// closed is closed (once) when Close is first called.
	closed chan struct{}
	// once makes repeated Close calls safe against a double channel close.
	once sync.Once
}

// Close first records the call on the closed channel, then closes the
// wrapped connection and returns its error.
func (c *closeSpyConn) Close() error {
	signal := func() { close(c.closed) }
	c.once.Do(signal)
	inner := c.Conn
	return inner.Close()
}