79 lines
1.5 KiB
Go
79 lines
1.5 KiB
Go
|
|
package whois
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"errors"
|
||
|
|
"net"
|
||
|
|
"sync"
|
||
|
|
"testing"
|
||
|
|
"time"
|
||
|
|
)
|
||
|
|
|
||
|
|
func TestDialContextCancelClosesLateConnection(t *testing.T) {
|
||
|
|
left, right := net.Pipe()
|
||
|
|
defer right.Close()
|
||
|
|
|
||
|
|
spyConn := &closeSpyConn{
|
||
|
|
Conn: left,
|
||
|
|
closed: make(chan struct{}),
|
||
|
|
}
|
||
|
|
dialer := &blockingDialer{
|
||
|
|
ready: make(chan struct{}),
|
||
|
|
release: make(chan struct{}),
|
||
|
|
conn: spyConn,
|
||
|
|
}
|
||
|
|
c := NewClient().SetDialer(dialer)
|
||
|
|
|
||
|
|
ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond)
|
||
|
|
defer cancel()
|
||
|
|
|
||
|
|
errCh := make(chan error, 1)
|
||
|
|
go func() {
|
||
|
|
_, err := c.dialContext(ctx, "tcp", "example.com:43")
|
||
|
|
errCh <- err
|
||
|
|
}()
|
||
|
|
|
||
|
|
<-dialer.ready
|
||
|
|
err := <-errCh
|
||
|
|
if err == nil {
|
||
|
|
t.Fatal("expected context cancellation error")
|
||
|
|
}
|
||
|
|
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
|
||
|
|
t.Fatalf("unexpected error: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
close(dialer.release)
|
||
|
|
select {
|
||
|
|
case <-spyConn.closed:
|
||
|
|
case <-time.After(time.Second):
|
||
|
|
t.Fatal("expected late connection to be closed after context cancel")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// blockingDialer is a test dialer whose Dial parks until the test closes
// release. The first Dial call closes ready so the test can observe that a
// dial is in flight. Once released, Dial returns the preconfigured conn and err.
type blockingDialer struct {
	ready, release chan struct{}

	conn net.Conn
	err  error

	once sync.Once // ensures ready is closed only once
}

// Dial ignores its network and address arguments. On the first call it closes
// ready. It then waits to receive from release and returns the stored
// connection and error.
func (b *blockingDialer) Dial(_, _ string) (net.Conn, error) {
	announce := func() { close(b.ready) }
	b.once.Do(announce)

	<-b.release

	conn, err := b.conn, b.err
	return conn, err
}
|
||
|
|
|
||
|
|
// closeSpyConn wraps a net.Conn and records, by closing the closed channel,
// that Close has been called at least once.
type closeSpyConn struct {
	net.Conn

	closed chan struct{} // closed on the first Close call
	once   sync.Once     // ensures closed is closed only once
}

// Close closes the closed channel on the first call only, then closes the
// wrapped connection on every call and returns its error.
func (s *closeSpyConn) Close() error {
	signal := func() { close(s.closed) }
	s.once.Do(signal)
	return s.Conn.Close()
}
|