// Package starnet provides a listener and connection type that can serve
// TLS and plaintext clients on the same port by sniffing the first bytes a
// client sends, plus small helpers for dialing plain or TLS connections.
package starnet

import (
	"bytes"
	"crypto/tls"
	"errors"
	"io"
	"net"
	"sync"
	"time"
)

// ErrNoTLSConfig is returned by Conn.Read, Conn.Write and Conn.TlsConnection
// when a client speaks TLS but neither the static config nor the per-hostname
// callback supplied a *tls.Config to serve it with.
var ErrNoTLSConfig = errors.New("starnet: no TLS config available for client")

// Bytes that open a TLS record carrying a handshake message: the record
// content type (handshake) followed by the major protocol version, which is
// 0x03 for SSL 3.0 and every TLS version.
const (
	tlsRecordTypeHandshake = 0x16
	tlsVersionMajor        = 0x03
)

// myConn adapts a reader holding already-consumed bytes, plus an optional
// net.Conn, to the net.Conn interface so crypto/tls can run over data that
// was read while sniffing the protocol.
//
// In read-only mode only reader is used: Write fails with io.ErrClosedPipe,
// address getters return nil and the other methods are no-ops. Otherwise
// reads drain reader first and then continue from conn, and every other
// call is forwarded to conn.
type myConn struct {
	reader      io.Reader
	conn        net.Conn
	isReadOnly  bool
	multiReader io.Reader
}

func (c *myConn) Read(p []byte) (int, error) {
	if c.isReadOnly {
		return c.reader.Read(p)
	}
	if c.multiReader == nil {
		c.multiReader = io.MultiReader(c.reader, c.conn)
	}
	return c.multiReader.Read(p)
}

func (c *myConn) Write(p []byte) (int, error) {
	if c.isReadOnly {
		return 0, io.ErrClosedPipe
	}
	return c.conn.Write(p)
}

func (c *myConn) Close() error {
	if c.isReadOnly {
		return nil
	}
	return c.conn.Close()
}

func (c *myConn) LocalAddr() net.Addr {
	if c.isReadOnly {
		return nil
	}
	return c.conn.LocalAddr()
}

func (c *myConn) RemoteAddr() net.Addr {
	if c.isReadOnly {
		return nil
	}
	return c.conn.RemoteAddr()
}

func (c *myConn) SetDeadline(t time.Time) error {
	if c.isReadOnly {
		return nil
	}
	return c.conn.SetDeadline(t)
}

func (c *myConn) SetReadDeadline(t time.Time) error {
	if c.isReadOnly {
		return nil
	}
	return c.conn.SetReadDeadline(t)
}

func (c *myConn) SetWriteDeadline(t time.Time) error {
	if c.isReadOnly {
		return nil
	}
	return c.conn.SetWriteDeadline(t)
}

// Listener wraps a net.Listener; its Accept returns *Conn values that decide
// on first use whether to serve TLS or plaintext.
type Listener struct {
	net.Listener
	cfg                *tls.Config
	getConfigForClient func(hostname string) *tls.Config
	allowNonTls        bool
}

// GetConfigForClient returns the per-hostname TLS config callback, or nil.
func (l *Listener) GetConfigForClient() func(hostname string) *tls.Config {
	return l.getConfigForClient
}

// SetConfigForClient sets the per-hostname TLS config callback used by
// connections accepted after this call.
func (l *Listener) SetConfigForClient(getConfigForClient func(hostname string) *tls.Config) {
	l.getConfigForClient = getConfigForClient
}

// Listen announces on the given network address and returns a Listener with
// no TLS configuration; accepted connections carry plaintext.
func Listen(network, address string) (*Listener, error) {
	listener, err := net.Listen(network, address)
	if err != nil {
		return nil, err
	}
	// Without allowNonTls every accepted Conn would refuse plaintext, and
	// with no TLS config it could never serve TLS either.
	return &Listener{Listener: listener, allowNonTls: true}, nil
}

// ListenTLSWithConfig announces on the given network address and returns a
// Listener that serves TLS using config and/or getConfigForClient (which
// receives the client's SNI name). When allowNonTls is true, clients that do
// not start with a TLS handshake are served as plaintext.
func ListenTLSWithConfig(network, address string, config *tls.Config, getConfigForClient func(hostname string) *tls.Config, allowNonTls bool) (*Listener, error) {
	listener, err := net.Listen(network, address)
	if err != nil {
		return nil, err
	}
	return &Listener{
		Listener:           listener,
		cfg:                config,
		getConfigForClient: getConfigForClient,
		allowNonTls:        allowNonTls,
	}, nil
}

// ListenTLS is like ListenTLSWithConfig but builds the server config from a
// PEM certificate/key pair on disk.
func ListenTLS(network, address string, certFile, keyFile string, allowNonTls bool) (*Listener, error) {
	config, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		return nil, err
	}
	tlsConfig := &tls.Config{
		Certificates: []tls.Certificate{config},
	}
	listener, err := net.Listen(network, address)
	if err != nil {
		return nil, err
	}
	return &Listener{
		Listener:    listener,
		cfg:         tlsConfig,
		allowNonTls: allowNonTls,
	}, nil
}

// Accept waits for the next connection and wraps it in a *Conn. Protocol
// detection is deferred to the Conn's first Read or Write.
func (l *Listener) Accept() (net.Conn, error) {
	conn, err := l.Listener.Accept()
	if err != nil {
		return nil, err
	}
	return &Conn{
		Conn:               conn,
		tlsCfg:             l.cfg,
		getConfigForClient: l.getConfigForClient,
		allowNonTls:        l.allowNonTls,
	}, nil
}

// Conn is a net.Conn that, when produced by Listener.Accept, decides on its
// first Read or Write whether the peer speaks TLS and then transparently
// serves either TLS or plaintext. Conns built by NewClientTlsConn,
// NewServerTlsConn and Dial skip detection and use the mode fixed at
// construction.
type Conn struct {
	net.Conn
	once sync.Once
	// initErr is the sticky detection/setup error returned by Read and Write.
	initErr error
	isTLS   bool
	// tlsCfg is the default server config (or the client config for
	// connections built by NewClientTlsConn / DialTLS).
	tlsCfg  *tls.Config
	tlsConn *tls.Conn
	// buffer holds bytes consumed from Conn while sniffing; they are
	// replayed before any further socket data.
	buffer      *bytes.Buffer
	noTlsReader io.Reader
	// isOriginal marks connections whose mode was fixed at construction.
	isOriginal         bool
	getConfigForClient func(hostname string) *tls.Config
	hostname           string
	allowNonTls        bool
}

// Hostname returns the name associated with the connection: the SNI server
// name captured during detection or from the completed handshake, or "" if
// none is known yet. During detection, a ClientHello without SNI records
// the peer's IP instead. A non-empty result is cached.
func (c *Conn) Hostname() string {
	if c.hostname != "" {
		return c.hostname
	}
	if c.isTLS && c.tlsConn != nil {
		if name := c.tlsConn.ConnectionState().ServerName; name != "" {
			c.hostname = name
			return c.hostname
		}
	}
	return ""
}

// IsTLS reports whether the connection is (or was detected as) TLS. For
// accepted connections this is only meaningful after the first Read/Write.
func (c *Conn) IsTLS() bool {
	return c.isTLS
}

// TlsConn returns the underlying *tls.Conn, or nil if the connection is not
// TLS or has not been initialized yet.
func (c *Conn) TlsConn() *tls.Conn {
	return c.tlsConn
}

// isTLSConnection sniffs the start of the stream to decide whether the peer
// is beginning a TLS handshake. Every byte it consumes is kept in c.buffer
// so it can be replayed. When a per-hostname callback is configured it also
// captures the client's SNI name into c.hostname.
func (c *Conn) isTLSConnection() (bool, error) {
	c.buffer = new(bytes.Buffer)

	// Read a single byte first: a plaintext client may send fewer bytes than
	// a TLS record header and then wait for a reply, so demanding more up
	// front could block forever.
	header := make([]byte, 3)
	if _, err := io.ReadFull(c.Conn, header[:1]); err != nil {
		return false, err
	}
	c.buffer.Write(header[:1])
	if header[0] != tlsRecordTypeHandshake {
		return false, nil
	}

	// The stream starts like a TLS handshake record; a real TLS client always
	// sends the full record header, so waiting for the version byte is safe.
	n, err := io.ReadFull(c.Conn, header[1:])
	c.buffer.Write(header[1 : 1+n])
	if err != nil {
		return false, err
	}
	if header[1] != tlsVersionMajor {
		return false, nil
	}

	if c.getConfigForClient != nil {
		c.captureServerName(header)
	}
	return true, nil
}

// captureServerName runs a throwaway server handshake over the already-read
// header plus the rest of the ClientHello, solely to learn the SNI name.
// Everything it pulls from the socket is teed into c.buffer so the real
// handshake can replay it.
func (c *Conn) captureServerName(header []byte) {
	probe := io.MultiReader(bytes.NewReader(header), io.TeeReader(c.Conn, c.buffer))

	var (
		sawHello   bool
		serverName string
	)
	probeCfg := &tls.Config{
		GetConfigForClient: func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
			sawHello = true
			serverName = hello.ServerName
			return nil, nil
		},
	}
	// The probe always fails (no certificates, write-less conn); only the
	// ClientHello callback above matters, so the error is deliberately ignored.
	_ = tls.Server(&myConn{reader: probe, isReadOnly: true}, probeCfg).Handshake()

	if !sawHello {
		return
	}
	c.hostname = serverName
	if c.hostname == "" {
		// NOTE(review): this is the client's (remote) IP, not a local name;
		// LocalAddr may have been intended for IP-based config selection.
		c.hostname, _, _ = net.SplitHostPort(c.Conn.RemoteAddr().String())
	}
}

// init performs one-time protocol detection and setup for accepted
// connections. Any failure is stored in c.initErr.
func (c *Conn) init() {
	c.once.Do(func() {
		if c.isOriginal {
			return
		}
		// Sniff whenever a TLS config could be produced: either a static
		// config or a per-hostname callback is enough.
		if c.tlsCfg != nil || c.getConfigForClient != nil {
			isTLS, err := c.isTLSConnection()
			if err != nil {
				c.initErr = err
				return
			}
			c.isTLS = isTLS
		}

		if !c.isTLS {
			if !c.allowNonTls {
				c.initErr = net.ErrClosed
				return
			}
			// No sniffing happened (no TLS configured), so nothing to replay.
			if c.buffer == nil {
				c.noTlsReader = c.Conn
			} else {
				c.noTlsReader = io.MultiReader(c.buffer, c.Conn)
			}
			return
		}

		cfg := c.tlsCfg
		if c.getConfigForClient != nil {
			if hostCfg := c.getConfigForClient(c.hostname); hostCfg != nil {
				cfg = hostCfg
			}
		}
		if cfg == nil {
			c.initErr = ErrNoTLSConfig
			return
		}
		c.tlsConn = tls.Server(&myConn{
			reader:     c.buffer,
			conn:       c.Conn,
			isReadOnly: false,
		}, cfg)
	})
}

// Read reads decrypted (TLS) or raw (plaintext) application data.
func (c *Conn) Read(b []byte) (int, error) {
	c.init()
	if c.initErr != nil {
		return 0, c.initErr
	}
	if c.isTLS {
		return c.tlsConn.Read(b)
	}
	return c.noTlsReader.Read(b)
}

// Write writes application data, encrypting it when the connection is TLS.
func (c *Conn) Write(b []byte) (int, error) {
	c.init()
	if c.initErr != nil {
		return 0, c.initErr
	}
	if c.isTLS {
		return c.tlsConn.Write(b)
	}
	return c.Conn.Write(b)
}

// Close closes the TLS layer when present, otherwise the raw connection.
func (c *Conn) Close() error {
	if c.isTLS && c.tlsConn != nil {
		return c.tlsConn.Close()
	}
	return c.Conn.Close()
}

// SetDeadline sets read and write deadlines on the active layer.
func (c *Conn) SetDeadline(t time.Time) error {
	if c.isTLS && c.tlsConn != nil {
		return c.tlsConn.SetDeadline(t)
	}
	return c.Conn.SetDeadline(t)
}

// SetReadDeadline sets the read deadline on the active layer.
func (c *Conn) SetReadDeadline(t time.Time) error {
	if c.isTLS && c.tlsConn != nil {
		return c.tlsConn.SetReadDeadline(t)
	}
	return c.Conn.SetReadDeadline(t)
}

// SetWriteDeadline sets the write deadline on the active layer.
func (c *Conn) SetWriteDeadline(t time.Time) error {
	if c.isTLS && c.tlsConn != nil {
		return c.tlsConn.SetWriteDeadline(t)
	}
	return c.Conn.SetWriteDeadline(t)
}

// TlsConnection returns the underlying *tls.Conn. For accepted connections
// it runs protocol detection first if needed, which may block reading from
// the peer. It returns net.ErrClosed when the peer is not speaking TLS, or
// the detection/setup error.
func (c *Conn) TlsConnection() (*tls.Conn, error) {
	c.init()
	if c.initErr != nil {
		return nil, c.initErr
	}
	if !c.isTLS {
		return nil, net.ErrClosed
	}
	return c.tlsConn, nil
}

// NewClientTlsConn wraps conn as the client side of a TLS connection.
func NewClientTlsConn(conn net.Conn, cfg *tls.Config) (*Conn, error) {
	if conn == nil {
		return nil, net.ErrClosed
	}
	c := &Conn{
		Conn:       conn,
		isTLS:      true,
		tlsCfg:     cfg,
		tlsConn:    tls.Client(conn, cfg),
		isOriginal: true,
	}
	return c, nil
}

// NewServerTlsConn wraps conn as the server side of a TLS connection,
// without protocol detection.
func NewServerTlsConn(conn net.Conn, cfg *tls.Config) (*Conn, error) {
	if conn == nil {
		return nil, net.ErrClosed
	}
	c := &Conn{
		Conn:       conn,
		isTLS:      true,
		tlsCfg:     cfg,
		tlsConn:    tls.Server(conn, cfg),
		isOriginal: true,
	}
	c.init()
	return c, nil
}

// Dial connects to address and returns a plaintext *Conn.
func Dial(network, address string) (*Conn, error) {
	conn, err := net.Dial(network, address)
	if err != nil {
		return nil, err
	}
	return &Conn{
		Conn:        conn,
		isTLS:       false,
		tlsCfg:      nil,
		tlsConn:     nil,
		noTlsReader: conn,
		isOriginal:  true,
	}, nil
}

// DialTLS connects to address and returns a TLS client *Conn that presents
// the given certificate/key pair. The server certificate is verified against
// the host part of address; the handshake runs on first Read/Write.
func DialTLS(network, address string, certFile, keyFile string) (*Conn, error) {
	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		return nil, err
	}
	// crypto/tls refuses to handshake unless ServerName (or
	// InsecureSkipVerify) is set; derive it the same way tls.Dial does.
	serverName, _, splitErr := net.SplitHostPort(address)
	if splitErr != nil {
		serverName = address
	}
	tlsConfig := &tls.Config{
		Certificates: []tls.Certificate{cert},
		ServerName:   serverName,
	}
	conn, err := net.Dial(network, address)
	if err != nil {
		return nil, err
	}
	return NewClientTlsConn(conn, tlsConfig)
}