package whois

import (
	"context"
	"errors"
	"net"
	"sync"
	"testing"
	"time"
)

// TestDialContextCancelClosesLateConnection checks that dialContext returns a
// context error once its context expires while the configured dialer is still
// blocked, and that the connection the dialer hands back afterwards (the
// "late" connection) gets closed instead of leaking.
func TestDialContextCancelClosesLateConnection(t *testing.T) {
	// Upper bound for every wait in this test, so a regression fails with a
	// clear message instead of hanging until the global `go test` timeout.
	const waitTimeout = time.Second

	left, right := net.Pipe()
	defer right.Close()
	// The client side is normally closed via spyConn; closing it again here is
	// harmless for net.Pipe and covers early-failure paths.
	defer left.Close()

	spyConn := &closeSpyConn{
		Conn:   left,
		closed: make(chan struct{}),
	}
	dialer := &blockingDialer{
		ready:   make(chan struct{}),
		release: make(chan struct{}),
		conn:    spyConn,
	}

	// Unblock Dial on every exit path (including t.Fatal) so the goroutine
	// blocked inside it never outlives the test. The Once lets the explicit
	// release below and this deferred call coexist without a double close.
	var releaseOnce sync.Once
	release := func() { releaseOnce.Do(func() { close(dialer.release) }) }
	defer release()

	c := NewClient().SetDialer(dialer)

	ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond)
	defer cancel()

	// Buffered so the goroutine can always deliver its result and exit, even
	// if the test has already given up waiting.
	errCh := make(chan error, 1)
	go func() {
		_, err := c.dialContext(ctx, "tcp", "example.com:43")
		errCh <- err
	}()

	select {
	case <-dialer.ready:
	case <-time.After(waitTimeout):
		t.Fatal("dialer was never invoked by dialContext")
	}

	// The dialer is still blocked at this point, so dialContext must return
	// on its own once ctx expires.
	var err error
	select {
	case err = <-errCh:
	case <-time.After(waitTimeout):
		t.Fatal("dialContext did not return after its context expired while the dialer was blocked")
	}
	if err == nil {
		t.Fatal("expected context cancellation error")
	}
	if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
		t.Fatalf("unexpected error: %v", err)
	}

	// Let the dialer finish; the connection it now returns arrives after the
	// caller gave up and must be closed.
	release()

	select {
	case <-spyConn.closed:
	case <-time.After(waitTimeout):
		t.Fatal("expected late connection to be closed after context cancel")
	}
}

// blockingDialer is a test dialer handed to Client.SetDialer. On its first
// Dial call it closes ready. Every call then blocks until release is closed
// and returns the preconfigured conn and err.
type blockingDialer struct {
	ready   chan struct{}
	release chan struct{}
	conn    net.Conn
	err     error
	once    sync.Once
}

// Dial ignores network and address. It signals ready (only once), waits for
// release to be closed, and then returns d.conn and d.err.
func (d *blockingDialer) Dial(_, _ string) (net.Conn, error) {
	d.once.Do(func() { close(d.ready) })
	<-d.release
	return d.conn, d.err
}

// closeSpyConn wraps a net.Conn and closes the closed channel the first time
// Close is called, so a test can observe whether the connection was closed.
type closeSpyConn struct {
	net.Conn
	closed chan struct{}
	once   sync.Once
}

// Close records the close (only the first call closes c.closed) and then
// closes the wrapped connection, returning its error.
func (c *closeSpyConn) Close() error {
	c.once.Do(func() { close(c.closed) })
	return c.Conn.Close()
}