Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
4e17fee681
|
|||
|
a8eed30db5
|
|||
| c1eaf43058 | |||
| 9f5aca124d | |||
| 54958724e7 | |||
| 7a17672149 | |||
| 44b807d3d1 | |||
| 0d847462b3 |
@@ -0,0 +1 @@
|
||||
.idea
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -108,6 +109,186 @@ type Request struct {
|
||||
RequestOpts
|
||||
}
|
||||
|
||||
func (r *Request) Clone() *Request {
|
||||
clonedRequest := &Request{
|
||||
ctx: r.ctx,
|
||||
uri: r.uri,
|
||||
method: r.method,
|
||||
errInfo: r.errInfo,
|
||||
RequestOpts: RequestOpts{
|
||||
headers: CloneHeader(r.headers),
|
||||
cookies: CloneCookies(r.cookies),
|
||||
bodyFormData: CloneStringMapSlice(r.bodyFormData),
|
||||
bodyFileData: CloneFiles(r.bodyFileData),
|
||||
queries: CloneStringMapSlice(r.queries),
|
||||
bodyDataBytes: CloneByteSlice(r.bodyDataBytes),
|
||||
proxy: r.proxy,
|
||||
timeout: r.timeout,
|
||||
dialTimeout: r.dialTimeout,
|
||||
dialFn: r.dialFn,
|
||||
alreadyApply: r.alreadyApply,
|
||||
disableRedirect: r.disableRedirect,
|
||||
doRawRequest: r.doRawRequest,
|
||||
doRawClient: r.doRawClient,
|
||||
doRawTransport: r.doRawTransport,
|
||||
skipTLSVerify: r.skipTLSVerify,
|
||||
autoFetchRespBody: r.autoFetchRespBody,
|
||||
customIP: CloneStringSlice(r.customIP),
|
||||
alreadySetLookUpIPfn: r.alreadySetLookUpIPfn,
|
||||
lookUpIPfn: r.lookUpIPfn,
|
||||
customDNS: CloneStringSlice(r.customDNS),
|
||||
basicAuth: r.basicAuth,
|
||||
autoCalcContentLength: r.autoCalcContentLength,
|
||||
},
|
||||
}
|
||||
|
||||
// 手动深拷贝嵌套引用类型
|
||||
if r.bodyDataReader != nil {
|
||||
clonedRequest.bodyDataReader = r.bodyDataReader
|
||||
}
|
||||
|
||||
if r.FileUploadRecallFn != nil {
|
||||
clonedRequest.FileUploadRecallFn = r.FileUploadRecallFn
|
||||
}
|
||||
|
||||
// 对于 tlsConfig 类型,需要手动复制
|
||||
if r.tlsConfig != nil {
|
||||
clonedRequest.tlsConfig = CloneTLSConfig(r.tlsConfig)
|
||||
}
|
||||
|
||||
// 对于 http.Transport,需要进行手动复制
|
||||
if r.transport != nil {
|
||||
clonedRequest.transport = CloneTransport(r.transport)
|
||||
}
|
||||
|
||||
return clonedRequest
|
||||
}
|
||||
|
||||
// CloneHeader 复制 http.Header
|
||||
func CloneHeader(original http.Header) http.Header {
|
||||
newHeader := make(http.Header)
|
||||
for key, values := range original {
|
||||
copiedValues := make([]string, len(values))
|
||||
copy(copiedValues, values)
|
||||
newHeader[key] = copiedValues
|
||||
}
|
||||
return newHeader
|
||||
}
|
||||
|
||||
// CloneCookies 复制 []*http.Cookie
|
||||
func CloneCookies(original []*http.Cookie) []*http.Cookie {
|
||||
cloned := make([]*http.Cookie, len(original))
|
||||
for i, cookie := range original {
|
||||
cloned[i] = &http.Cookie{
|
||||
Name: cookie.Name,
|
||||
Value: cookie.Value,
|
||||
Path: cookie.Path,
|
||||
Domain: cookie.Domain,
|
||||
Expires: cookie.Expires,
|
||||
RawExpires: cookie.RawExpires,
|
||||
MaxAge: cookie.MaxAge,
|
||||
Secure: cookie.Secure,
|
||||
HttpOnly: cookie.HttpOnly,
|
||||
SameSite: cookie.SameSite,
|
||||
Raw: cookie.Raw,
|
||||
Unparsed: append([]string(nil), cookie.Unparsed...),
|
||||
}
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
// CloneStringMapSlice 复制 map[string][]string
|
||||
func CloneStringMapSlice(original map[string][]string) map[string][]string {
|
||||
newMap := make(map[string][]string)
|
||||
for key, values := range original {
|
||||
copiedValues := make([]string, len(values))
|
||||
copy(copiedValues, values)
|
||||
newMap[key] = copiedValues
|
||||
}
|
||||
return newMap
|
||||
}
|
||||
|
||||
// CloneFiles 复制 []RequestFile
|
||||
func CloneFiles(original []RequestFile) []RequestFile {
|
||||
newFiles := make([]RequestFile, len(original))
|
||||
copy(newFiles, original)
|
||||
return newFiles
|
||||
}
|
||||
|
||||
// CloneByteSlice 复制 []byte
|
||||
func CloneByteSlice(original []byte) []byte {
|
||||
if original == nil {
|
||||
return nil
|
||||
}
|
||||
newSlice := make([]byte, len(original))
|
||||
copy(newSlice, original)
|
||||
return newSlice
|
||||
}
|
||||
|
||||
// CloneStringSlice 复制 []string
|
||||
func CloneStringSlice(original []string) []string {
|
||||
newSlice := make([]string, len(original))
|
||||
copy(newSlice, original)
|
||||
return newSlice
|
||||
}
|
||||
|
||||
// CloneTLSConfig 复制 tls.Config
|
||||
func CloneTLSConfig(original *tls.Config) *tls.Config {
|
||||
newConfig := &tls.Config{
|
||||
Rand: original.Rand,
|
||||
Time: original.Time,
|
||||
Certificates: append([]tls.Certificate(nil), original.Certificates...),
|
||||
NameToCertificate: original.NameToCertificate,
|
||||
GetCertificate: original.GetCertificate,
|
||||
GetClientCertificate: original.GetClientCertificate,
|
||||
GetConfigForClient: original.GetConfigForClient,
|
||||
VerifyPeerCertificate: original.VerifyPeerCertificate,
|
||||
VerifyConnection: original.VerifyConnection,
|
||||
RootCAs: original.RootCAs,
|
||||
NextProtos: append([]string(nil), original.NextProtos...),
|
||||
ServerName: original.ServerName,
|
||||
ClientAuth: original.ClientAuth,
|
||||
ClientCAs: original.ClientCAs,
|
||||
InsecureSkipVerify: original.InsecureSkipVerify,
|
||||
CipherSuites: append([]uint16(nil), original.CipherSuites...),
|
||||
PreferServerCipherSuites: original.PreferServerCipherSuites,
|
||||
SessionTicketsDisabled: original.SessionTicketsDisabled,
|
||||
SessionTicketKey: original.SessionTicketKey,
|
||||
ClientSessionCache: original.ClientSessionCache,
|
||||
MinVersion: original.MinVersion,
|
||||
MaxVersion: original.MaxVersion,
|
||||
CurvePreferences: append([]tls.CurveID(nil), original.CurvePreferences...),
|
||||
DynamicRecordSizingDisabled: original.DynamicRecordSizingDisabled,
|
||||
Renegotiation: original.Renegotiation,
|
||||
KeyLogWriter: original.KeyLogWriter,
|
||||
}
|
||||
return newConfig
|
||||
}
|
||||
|
||||
// CloneTransport 复制 http.Transport
|
||||
func CloneTransport(original *http.Transport) *http.Transport {
|
||||
newTransport := &http.Transport{
|
||||
Proxy: original.Proxy,
|
||||
DialContext: original.DialContext,
|
||||
Dial: original.Dial,
|
||||
DialTLS: original.DialTLS,
|
||||
TLSClientConfig: original.TLSClientConfig,
|
||||
TLSHandshakeTimeout: original.TLSHandshakeTimeout,
|
||||
DisableKeepAlives: original.DisableKeepAlives,
|
||||
DisableCompression: original.DisableCompression,
|
||||
MaxIdleConns: original.MaxIdleConns,
|
||||
MaxIdleConnsPerHost: original.MaxIdleConnsPerHost,
|
||||
IdleConnTimeout: original.IdleConnTimeout,
|
||||
ResponseHeaderTimeout: original.ResponseHeaderTimeout,
|
||||
ExpectContinueTimeout: original.ExpectContinueTimeout,
|
||||
TLSNextProto: original.TLSNextProto,
|
||||
ProxyConnectHeader: original.ProxyConnectHeader,
|
||||
MaxResponseHeaderBytes: original.MaxResponseHeaderBytes,
|
||||
WriteBufferSize: original.WriteBufferSize,
|
||||
ReadBufferSize: original.ReadBufferSize,
|
||||
}
|
||||
return newTransport
|
||||
}
|
||||
func (r *Request) Method() string {
|
||||
return r.method
|
||||
}
|
||||
@@ -202,6 +383,7 @@ type RequestOpts struct {
|
||||
proxy string
|
||||
timeout time.Duration
|
||||
dialTimeout time.Duration
|
||||
dialFn func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
headers http.Header
|
||||
cookies []*http.Cookie
|
||||
transport *http.Transport
|
||||
@@ -224,6 +406,14 @@ type RequestOpts struct {
|
||||
autoCalcContentLength bool
|
||||
}
|
||||
|
||||
func (r *Request) DialFn() func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return r.dialFn
|
||||
}
|
||||
|
||||
func (r *Request) SetDialFn(dialFn func(ctx context.Context, network, addr string) (net.Conn, error)) {
|
||||
r.dialFn = dialFn
|
||||
}
|
||||
|
||||
func (r *Request) AutoCalcContentLength() bool {
|
||||
return r.autoCalcContentLength
|
||||
}
|
||||
@@ -665,6 +855,14 @@ func (r *Request) AddFileWithNameAndTypeNoError(formName, filepath, filename, fi
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Request) HttpClient() (*http.Client, error) {
|
||||
err := applyOptions(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.rawClient, nil
|
||||
}
|
||||
|
||||
type RequestFile struct {
|
||||
FormName string
|
||||
FileName string
|
||||
@@ -683,6 +881,14 @@ func WithDialTimeout(timeout time.Duration) RequestOpt {
|
||||
}
|
||||
}
|
||||
|
||||
// if doRawTransport is true, this function will nolonger work
|
||||
func WithDial(fn func(ctx context.Context, network string, addr string) (net.Conn, error)) RequestOpt {
|
||||
return func(opt *RequestOpts) error {
|
||||
opt.dialFn = fn
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// if doRawTransport is true, this function will nolonger work
|
||||
func WithTimeout(timeout time.Duration) RequestOpt {
|
||||
return func(opt *RequestOpts) error {
|
||||
@@ -1065,16 +1271,24 @@ type Response struct {
|
||||
*http.Response
|
||||
req Request
|
||||
data *Body
|
||||
rawClient *http.Client
|
||||
}
|
||||
|
||||
type Body struct {
|
||||
full []byte
|
||||
raw io.ReadCloser
|
||||
isFull bool
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func (b *Body) readAll() {
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
if !b.isFull {
|
||||
if b.raw == nil {
|
||||
b.isFull = true
|
||||
return
|
||||
}
|
||||
b.full, _ = io.ReadAll(b.raw)
|
||||
b.isFull = true
|
||||
b.raw.Close()
|
||||
@@ -1099,6 +1313,8 @@ func (b *Body) Unmarshal(u interface{}) error {
|
||||
// Reader returns a reader for the body
|
||||
// if this function is called, other functions like String, Bytes, Unmarshal may not work
|
||||
func (b *Body) Reader() io.ReadCloser {
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
if b.isFull {
|
||||
return io.NopCloser(bytes.NewReader(b.full))
|
||||
}
|
||||
@@ -1118,6 +1334,24 @@ func (r *Response) Body() *Body {
|
||||
return r.data
|
||||
}
|
||||
|
||||
func (r *Response) Close() error {
|
||||
if r != nil && r.data != nil && r.data.raw != nil {
|
||||
return r.Response.Body.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Response) CloseAll() error {
|
||||
if r.rawClient != nil {
|
||||
r.rawClient.CloseIdleConnections()
|
||||
}
|
||||
return r.Close()
|
||||
}
|
||||
|
||||
func (r *Response) HttpClient() *http.Client {
|
||||
return r.rawClient
|
||||
}
|
||||
|
||||
func Curl(r *Request) (*Response, error) {
|
||||
r.errInfo = nil
|
||||
err := applyOptions(r)
|
||||
@@ -1129,6 +1363,7 @@ func Curl(r *Request) (*Response, error) {
|
||||
Response: resp,
|
||||
req: *r,
|
||||
data: new(Body),
|
||||
rawClient: r.rawClient,
|
||||
}
|
||||
if err != nil {
|
||||
res.Response = &http.Response{}
|
||||
@@ -1169,6 +1404,17 @@ func NewRequestWithContext(ctx context.Context, uri string, method string, opts
|
||||
return newRequest(ctx, uri, method, opts...)
|
||||
}
|
||||
|
||||
func NewHttpClient(opts ...RequestOpt) (*http.Client, error) {
|
||||
req, err := newRequest(context.Background(), "", "", opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
req = nil
|
||||
}()
|
||||
return req.HttpClient()
|
||||
}
|
||||
|
||||
func newRequest(ctx context.Context, uri string, method string, opts ...RequestOpt) (*Request, error) {
|
||||
var req *http.Request
|
||||
var err error
|
||||
@@ -1186,7 +1432,7 @@ func newRequest(ctx context.Context, uri string, method string, opts ...RequestO
|
||||
method: method,
|
||||
RequestOpts: RequestOpts{
|
||||
rawRequest: req,
|
||||
rawClient: new(http.Client),
|
||||
rawClient: nil,
|
||||
timeout: DefaultTimeout,
|
||||
dialTimeout: DefaultDialTimeout,
|
||||
autoFetchRespBody: DefaultFetchRespBody,
|
||||
@@ -1209,10 +1455,13 @@ func newRequest(ctx context.Context, uri string, method string, opts ...RequestO
|
||||
}
|
||||
}
|
||||
}
|
||||
if r.transport == nil {
|
||||
if r.transport == nil && !r.doRawTransport {
|
||||
r.transport = &http.Transport{}
|
||||
}
|
||||
if r.doRawTransport {
|
||||
if r.rawClient == nil && !r.doRawClient {
|
||||
r.rawClient = new(http.Client)
|
||||
}
|
||||
if !r.doRawTransport {
|
||||
if r.skipTLSVerify {
|
||||
if r.transport.TLSClientConfig == nil {
|
||||
r.transport.TLSClientConfig = &tls.Config{}
|
||||
@@ -1261,6 +1510,9 @@ func newRequest(ctx context.Context, uri string, method string, opts ...RequestO
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
if r.dialFn != nil {
|
||||
r.transport.DialContext = r.dialFn
|
||||
}
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
@@ -1397,6 +1649,18 @@ func applyOptions(r *Request) error {
|
||||
r.lookUpIPfn = resolver.LookupIPAddr
|
||||
}
|
||||
}
|
||||
if r.skipTLSVerify {
|
||||
if r.transport.TLSClientConfig == nil {
|
||||
r.transport.TLSClientConfig = &tls.Config{}
|
||||
}
|
||||
r.transport.TLSClientConfig.InsecureSkipVerify = true
|
||||
}
|
||||
if r.tlsConfig != nil {
|
||||
r.transport.TLSClientConfig = r.tlsConfig
|
||||
}
|
||||
if r.dialFn != nil {
|
||||
r.transport.DialContext = r.dialFn
|
||||
}
|
||||
r.rawClient.Transport = r.transport
|
||||
if r.disableRedirect {
|
||||
r.rawClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
@@ -1448,3 +1712,36 @@ func copyWithContext(ctx context.Context, recall func(string, int64, int64), fil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func NewReqWithClient(client *http.Client, uri string, opts ...RequestOpt) *Request {
|
||||
return NewSimpleRequestWithClient(client, uri, "GET", opts...)
|
||||
}
|
||||
|
||||
func NewReqWithContextWithClient(ctx context.Context, client *http.Client, uri string, opts ...RequestOpt) *Request {
|
||||
return NewSimpleRequestWithContextWithClient(ctx, client, uri, "GET", opts...)
|
||||
}
|
||||
|
||||
func NewSimpleRequestWithClient(client *http.Client, uri string, method string, opts ...RequestOpt) *Request {
|
||||
r, _ := NewRequestWithContextWithClient(context.Background(), client, uri, method, opts...)
|
||||
return r
|
||||
}
|
||||
|
||||
func NewRequestWithClient(client *http.Client, uri string, method string, opts ...RequestOpt) (*Request, error) {
|
||||
return NewRequestWithContextWithClient(context.Background(), client, uri, method, opts...)
|
||||
}
|
||||
|
||||
func NewSimpleRequestWithContextWithClient(ctx context.Context, client *http.Client, uri string, method string, opts ...RequestOpt) *Request {
|
||||
r, _ := NewRequestWithContextWithClient(ctx, client, uri, method, opts...)
|
||||
return r
|
||||
}
|
||||
|
||||
func NewRequestWithContextWithClient(ctx context.Context, client *http.Client, uri string, method string, opts ...RequestOpt) (*Request, error) {
|
||||
req, err := newRequest(ctx, uri, method, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.rawClient = client
|
||||
req.SetDoRawClient(true)
|
||||
req.SetDoRawTransport(true)
|
||||
return req, err
|
||||
}
|
||||
|
||||
+401
@@ -0,0 +1,401 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type myConn struct {
|
||||
reader io.Reader
|
||||
conn net.Conn
|
||||
isReadOnly bool
|
||||
multiReader io.Reader
|
||||
}
|
||||
|
||||
func (c *myConn) Read(p []byte) (int, error) {
|
||||
if c.isReadOnly {
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
if c.multiReader == nil {
|
||||
c.multiReader = io.MultiReader(c.reader, c.conn)
|
||||
}
|
||||
return c.multiReader.Read(p)
|
||||
}
|
||||
|
||||
func (c *myConn) Write(p []byte) (int, error) {
|
||||
if c.isReadOnly {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
return c.conn.Write(p)
|
||||
}
|
||||
func (c *myConn) Close() error {
|
||||
if c.isReadOnly {
|
||||
return nil
|
||||
}
|
||||
return c.conn.Close()
|
||||
}
|
||||
func (c *myConn) LocalAddr() net.Addr {
|
||||
if c.isReadOnly {
|
||||
return nil
|
||||
}
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
func (c *myConn) RemoteAddr() net.Addr {
|
||||
if c.isReadOnly {
|
||||
return nil
|
||||
}
|
||||
return c.conn.RemoteAddr()
|
||||
}
|
||||
func (c *myConn) SetDeadline(t time.Time) error {
|
||||
if c.isReadOnly {
|
||||
return nil
|
||||
}
|
||||
return c.conn.SetDeadline(t)
|
||||
}
|
||||
func (c *myConn) SetReadDeadline(t time.Time) error {
|
||||
if c.isReadOnly {
|
||||
return nil
|
||||
}
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
func (c *myConn) SetWriteDeadline(t time.Time) error {
|
||||
if c.isReadOnly {
|
||||
return nil
|
||||
}
|
||||
return c.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
type Listener struct {
|
||||
net.Listener
|
||||
cfg *tls.Config
|
||||
getConfigForClient func(hostname string) *tls.Config
|
||||
allowNonTls bool
|
||||
}
|
||||
|
||||
func (l *Listener) GetConfigForClient() func(hostname string) *tls.Config {
|
||||
return l.getConfigForClient
|
||||
}
|
||||
|
||||
func (l *Listener) SetConfigForClient(getConfigForClient func(hostname string) *tls.Config) {
|
||||
l.getConfigForClient = getConfigForClient
|
||||
}
|
||||
|
||||
func Listen(network, address string) (*Listener, error) {
|
||||
listener, err := net.Listen(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Listener{Listener: listener}, nil
|
||||
}
|
||||
|
||||
func ListenTLSWithListenConfig(liscfg net.ListenConfig, network, address string, config *tls.Config, getConfigForClient func(hostname string) *tls.Config, allowNonTls bool) (*Listener, error) {
|
||||
listener, err := liscfg.Listen(context.Background(), network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Listener{
|
||||
Listener: listener,
|
||||
cfg: config,
|
||||
getConfigForClient: getConfigForClient,
|
||||
allowNonTls: allowNonTls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ListenWithListener(listener net.Listener, config *tls.Config, getConfigForClient func(hostname string) *tls.Config, allowNonTls bool) (*Listener, error) {
|
||||
return &Listener{
|
||||
Listener: listener,
|
||||
cfg: config,
|
||||
getConfigForClient: getConfigForClient,
|
||||
allowNonTls: allowNonTls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ListenTLSWithConfig(network, address string, config *tls.Config, getConfigForClient func(hostname string) *tls.Config, allowNonTls bool) (*Listener, error) {
|
||||
listener, err := net.Listen(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Listener{
|
||||
Listener: listener,
|
||||
cfg: config,
|
||||
getConfigForClient: getConfigForClient,
|
||||
allowNonTls: allowNonTls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ListenTLS(network, address string, certFile, keyFile string, allowNonTls bool) (*Listener, error) {
|
||||
config, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{config},
|
||||
}
|
||||
|
||||
listener, err := net.Listen(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Listener{
|
||||
Listener: listener,
|
||||
cfg: tlsConfig,
|
||||
allowNonTls: allowNonTls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (l *Listener) Accept() (net.Conn, error) {
|
||||
conn, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Conn{
|
||||
Conn: conn,
|
||||
tlsCfg: l.cfg,
|
||||
getConfigForClient: l.getConfigForClient,
|
||||
allowNonTls: l.allowNonTls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
net.Conn
|
||||
once sync.Once
|
||||
initErr error
|
||||
isTLS bool
|
||||
tlsCfg *tls.Config
|
||||
tlsConn *tls.Conn
|
||||
buffer *bytes.Buffer
|
||||
noTlsReader io.Reader
|
||||
isOriginal bool
|
||||
getConfigForClient func(hostname string) *tls.Config
|
||||
hostname string
|
||||
allowNonTls bool
|
||||
}
|
||||
|
||||
func (c *Conn) Hostname() string {
|
||||
if c.hostname != "" {
|
||||
return c.hostname
|
||||
}
|
||||
if c.isTLS && c.tlsConn != nil {
|
||||
if c.tlsConn.ConnectionState().ServerName != "" {
|
||||
c.hostname = c.tlsConn.ConnectionState().ServerName
|
||||
return c.hostname
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *Conn) IsTLS() bool {
|
||||
return c.isTLS
|
||||
}
|
||||
|
||||
func (c *Conn) TlsConn() *tls.Conn {
|
||||
return c.tlsConn
|
||||
}
|
||||
|
||||
func (c *Conn) isTLSConnection() (bool, error) {
|
||||
if c.getConfigForClient == nil {
|
||||
peek := make([]byte, 5)
|
||||
n, err := io.ReadFull(c.Conn, peek)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
isTLS := n >= 3 && peek[0] == 0x16 && peek[1] == 0x03
|
||||
|
||||
c.buffer = bytes.NewBuffer(peek[:n])
|
||||
return isTLS, nil
|
||||
}
|
||||
|
||||
c.buffer = new(bytes.Buffer)
|
||||
r := io.TeeReader(c.Conn, c.buffer)
|
||||
var hello *tls.ClientHelloInfo
|
||||
tls.Server(&myConn{reader: r, isReadOnly: true}, &tls.Config{
|
||||
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
hello = new(tls.ClientHelloInfo)
|
||||
*hello = *argHello
|
||||
return nil, nil
|
||||
},
|
||||
}).Handshake()
|
||||
peek := c.buffer.Bytes()
|
||||
n := len(peek)
|
||||
isTLS := n >= 3 && peek[0] == 0x16 && peek[1] == 0x03
|
||||
if hello == nil {
|
||||
return isTLS, nil
|
||||
}
|
||||
c.hostname = hello.ServerName
|
||||
if c.hostname == "" {
|
||||
c.hostname, _, _ = net.SplitHostPort(c.Conn.LocalAddr().String())
|
||||
}
|
||||
return isTLS, nil
|
||||
}
|
||||
|
||||
func (c *Conn) init() {
|
||||
c.once.Do(func() {
|
||||
if c.isOriginal {
|
||||
return
|
||||
}
|
||||
if c.tlsCfg != nil {
|
||||
isTLS, err := c.isTLSConnection()
|
||||
if err != nil {
|
||||
c.initErr = err
|
||||
return
|
||||
}
|
||||
c.isTLS = isTLS
|
||||
}
|
||||
|
||||
if c.isTLS {
|
||||
var cfg = c.tlsCfg
|
||||
if c.getConfigForClient != nil {
|
||||
cfg = c.getConfigForClient(c.hostname)
|
||||
if cfg == nil {
|
||||
cfg = c.tlsCfg
|
||||
}
|
||||
}
|
||||
c.tlsConn = tls.Server(&myConn{
|
||||
reader: c.buffer,
|
||||
conn: c.Conn,
|
||||
isReadOnly: false,
|
||||
}, cfg)
|
||||
} else {
|
||||
if !c.allowNonTls {
|
||||
c.initErr = net.ErrClosed
|
||||
return
|
||||
}
|
||||
c.noTlsReader = io.MultiReader(c.buffer, c.Conn)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (int, error) {
|
||||
c.init()
|
||||
if c.initErr != nil {
|
||||
return 0, c.initErr
|
||||
}
|
||||
if c.isTLS {
|
||||
return c.tlsConn.Read(b)
|
||||
}
|
||||
return c.noTlsReader.Read(b)
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (int, error) {
|
||||
c.init()
|
||||
if c.initErr != nil {
|
||||
return 0, c.initErr
|
||||
}
|
||||
|
||||
if c.isTLS {
|
||||
return c.tlsConn.Write(b)
|
||||
}
|
||||
return c.Conn.Write(b)
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
if c.isTLS && c.tlsConn != nil {
|
||||
return c.tlsConn.Close()
|
||||
}
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
if c.isTLS && c.tlsConn != nil {
|
||||
return c.tlsConn.SetDeadline(t)
|
||||
}
|
||||
return c.Conn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
if c.isTLS && c.tlsConn != nil {
|
||||
return c.tlsConn.SetReadDeadline(t)
|
||||
}
|
||||
return c.Conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
if c.isTLS && c.tlsConn != nil {
|
||||
return c.tlsConn.SetWriteDeadline(t)
|
||||
}
|
||||
return c.Conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (c *Conn) TlsConnection() (*tls.Conn, error) {
|
||||
if c.initErr != nil {
|
||||
return nil, c.initErr
|
||||
}
|
||||
if !c.isTLS {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
return c.tlsConn, nil
|
||||
}
|
||||
|
||||
func (c *Conn) OriginalConn() net.Conn {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
func NewClientTlsConn(conn net.Conn, cfg *tls.Config) (*Conn, error) {
|
||||
if conn == nil {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
c := &Conn{
|
||||
Conn: conn,
|
||||
isTLS: true,
|
||||
tlsCfg: cfg,
|
||||
tlsConn: tls.Client(conn, cfg),
|
||||
isOriginal: true,
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func NewServerTlsConn(conn net.Conn, cfg *tls.Config) (*Conn, error) {
|
||||
if conn == nil {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
c := &Conn{
|
||||
Conn: conn,
|
||||
isTLS: true,
|
||||
tlsCfg: cfg,
|
||||
tlsConn: tls.Server(conn, cfg),
|
||||
isOriginal: true,
|
||||
}
|
||||
c.init()
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func Dial(network, address string) (*Conn, error) {
|
||||
conn, err := net.Dial(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Conn{
|
||||
Conn: conn,
|
||||
isTLS: false,
|
||||
tlsCfg: nil,
|
||||
tlsConn: nil,
|
||||
noTlsReader: conn,
|
||||
isOriginal: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func DialTLS(network, address string, certFile, keyFile string) (*Conn, error) {
|
||||
config, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{config},
|
||||
}
|
||||
|
||||
conn, err := net.Dial(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewClientTlsConn(conn, tlsConfig)
|
||||
}
|
||||
Reference in New Issue
Block a user