402 lines
8.3 KiB
Go
402 lines
8.3 KiB
Go
package starnet
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"io"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// myConn exposes a buffered prefix plus an optional live connection through
// the net.Conn method set, which is what tls.Server expects.
//
// In read-only mode only reader is consulted: writes fail with
// io.ErrClosedPipe and Close, the address getters and the deadline setters do
// nothing. Otherwise reads drain reader first and then conn, and every other
// call is forwarded to conn.
type myConn struct {
	// reader supplies the bytes that are served before anything from conn.
	reader io.Reader

	// conn is the underlying connection. It is not touched in read-only mode.
	conn net.Conn

	// isReadOnly selects read-only mode.
	isReadOnly bool

	// multiReader chains reader and conn. It is built on the first Read that
	// happens outside read-only mode.
	multiReader io.Reader
}
|
|
|
|
func (c *myConn) Read(p []byte) (int, error) {
|
|
if c.isReadOnly {
|
|
return c.reader.Read(p)
|
|
}
|
|
if c.multiReader == nil {
|
|
c.multiReader = io.MultiReader(c.reader, c.conn)
|
|
}
|
|
return c.multiReader.Read(p)
|
|
}
|
|
|
|
func (c *myConn) Write(p []byte) (int, error) {
|
|
if c.isReadOnly {
|
|
return 0, io.ErrClosedPipe
|
|
}
|
|
return c.conn.Write(p)
|
|
}
|
|
func (c *myConn) Close() error {
|
|
if c.isReadOnly {
|
|
return nil
|
|
}
|
|
return c.conn.Close()
|
|
}
|
|
func (c *myConn) LocalAddr() net.Addr {
|
|
if c.isReadOnly {
|
|
return nil
|
|
}
|
|
return c.conn.LocalAddr()
|
|
}
|
|
func (c *myConn) RemoteAddr() net.Addr {
|
|
if c.isReadOnly {
|
|
return nil
|
|
}
|
|
return c.conn.RemoteAddr()
|
|
}
|
|
func (c *myConn) SetDeadline(t time.Time) error {
|
|
if c.isReadOnly {
|
|
return nil
|
|
}
|
|
return c.conn.SetDeadline(t)
|
|
}
|
|
func (c *myConn) SetReadDeadline(t time.Time) error {
|
|
if c.isReadOnly {
|
|
return nil
|
|
}
|
|
return c.conn.SetReadDeadline(t)
|
|
}
|
|
func (c *myConn) SetWriteDeadline(t time.Time) error {
|
|
if c.isReadOnly {
|
|
return nil
|
|
}
|
|
return c.conn.SetWriteDeadline(t)
|
|
}
|
|
|
|
// Listener wraps a net.Listener and hands out *Conn values that carry the
// listener's TLS settings.
type Listener struct {
	net.Listener

	// cfg is the TLS configuration passed to every accepted Conn.
	cfg *tls.Config

	// getConfigForClient, when non-nil, is passed to every accepted Conn,
	// which calls it with the detected hostname to choose a TLS configuration.
	getConfigForClient func(hostname string) *tls.Config

	// allowNonTls is passed to every accepted Conn. When it is false, a
	// connection that is not detected as TLS fails with net.ErrClosed.
	allowNonTls bool
}
|
|
|
|
func (l *Listener) GetConfigForClient() func(hostname string) *tls.Config {
|
|
return l.getConfigForClient
|
|
}
|
|
|
|
func (l *Listener) SetConfigForClient(getConfigForClient func(hostname string) *tls.Config) {
|
|
l.getConfigForClient = getConfigForClient
|
|
}
|
|
|
|
func Listen(network, address string) (*Listener, error) {
|
|
listener, err := net.Listen(network, address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &Listener{Listener: listener}, nil
|
|
}
|
|
|
|
func ListenTLSWithListenConfig(liscfg net.ListenConfig, network, address string, config *tls.Config, getConfigForClient func(hostname string) *tls.Config, allowNonTls bool) (*Listener, error) {
|
|
listener, err := liscfg.Listen(context.Background(), network, address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &Listener{
|
|
Listener: listener,
|
|
cfg: config,
|
|
getConfigForClient: getConfigForClient,
|
|
allowNonTls: allowNonTls,
|
|
}, nil
|
|
}
|
|
|
|
func ListenWithListener(listener net.Listener, config *tls.Config, getConfigForClient func(hostname string) *tls.Config, allowNonTls bool) (*Listener, error) {
|
|
return &Listener{
|
|
Listener: listener,
|
|
cfg: config,
|
|
getConfigForClient: getConfigForClient,
|
|
allowNonTls: allowNonTls,
|
|
}, nil
|
|
}
|
|
|
|
func ListenTLSWithConfig(network, address string, config *tls.Config, getConfigForClient func(hostname string) *tls.Config, allowNonTls bool) (*Listener, error) {
|
|
listener, err := net.Listen(network, address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &Listener{
|
|
Listener: listener,
|
|
cfg: config,
|
|
getConfigForClient: getConfigForClient,
|
|
allowNonTls: allowNonTls,
|
|
}, nil
|
|
}
|
|
|
|
func ListenTLS(network, address string, certFile, keyFile string, allowNonTls bool) (*Listener, error) {
|
|
config, err := tls.LoadX509KeyPair(certFile, keyFile)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
tlsConfig := &tls.Config{
|
|
Certificates: []tls.Certificate{config},
|
|
}
|
|
|
|
listener, err := net.Listen(network, address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &Listener{
|
|
Listener: listener,
|
|
cfg: tlsConfig,
|
|
allowNonTls: allowNonTls,
|
|
}, nil
|
|
}
|
|
|
|
func (l *Listener) Accept() (net.Conn, error) {
|
|
conn, err := l.Listener.Accept()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &Conn{
|
|
Conn: conn,
|
|
tlsCfg: l.cfg,
|
|
getConfigForClient: l.getConfigForClient,
|
|
allowNonTls: l.allowNonTls,
|
|
}, nil
|
|
}
|
|
|
|
// Conn is a net.Conn that may carry either TLS or plain traffic. For
// connections produced by Listener.Accept the protocol is detected lazily, on
// the first Read or Write.
type Conn struct {
	// Conn is the underlying transport.
	net.Conn

	// once guards the one-time setup done by init.
	once sync.Once
	// initErr is the error recorded by init; Read and Write return it.
	initErr error

	// isTLS reports whether traffic goes through tlsConn.
	isTLS bool
	// tlsCfg is the base TLS configuration.
	tlsCfg *tls.Config
	// tlsConn is the TLS layer, set once the connection is known to be TLS.
	tlsConn *tls.Conn

	// buffer holds the bytes consumed from Conn while detecting the protocol
	// so they can be replayed to the real reader.
	buffer *bytes.Buffer
	// noTlsReader is what Read uses for plain (non-TLS) connections.
	noTlsReader io.Reader

	// isOriginal marks connections built by NewClientTlsConn,
	// NewServerTlsConn or Dial; init does nothing for them.
	isOriginal bool

	// getConfigForClient optionally picks a TLS configuration by hostname.
	getConfigForClient func(hostname string) *tls.Config
	// hostname caches the name returned by Hostname.
	hostname string
	// allowNonTls permits plain traffic on a connection that is not TLS.
	allowNonTls bool
}
|
|
|
|
func (c *Conn) Hostname() string {
|
|
if c.hostname != "" {
|
|
return c.hostname
|
|
}
|
|
if c.isTLS && c.tlsConn != nil {
|
|
if c.tlsConn.ConnectionState().ServerName != "" {
|
|
c.hostname = c.tlsConn.ConnectionState().ServerName
|
|
return c.hostname
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func (c *Conn) IsTLS() bool {
|
|
return c.isTLS
|
|
}
|
|
|
|
func (c *Conn) TlsConn() *tls.Conn {
|
|
return c.tlsConn
|
|
}
|
|
|
|
func (c *Conn) isTLSConnection() (bool, error) {
|
|
if c.getConfigForClient == nil {
|
|
peek := make([]byte, 5)
|
|
n, err := io.ReadFull(c.Conn, peek)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
isTLS := n >= 3 && peek[0] == 0x16 && peek[1] == 0x03
|
|
|
|
c.buffer = bytes.NewBuffer(peek[:n])
|
|
return isTLS, nil
|
|
}
|
|
|
|
c.buffer = new(bytes.Buffer)
|
|
r := io.TeeReader(c.Conn, c.buffer)
|
|
var hello *tls.ClientHelloInfo
|
|
tls.Server(&myConn{reader: r, isReadOnly: true}, &tls.Config{
|
|
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
|
|
hello = new(tls.ClientHelloInfo)
|
|
*hello = *argHello
|
|
return nil, nil
|
|
},
|
|
}).Handshake()
|
|
peek := c.buffer.Bytes()
|
|
n := len(peek)
|
|
isTLS := n >= 3 && peek[0] == 0x16 && peek[1] == 0x03
|
|
if hello == nil {
|
|
return isTLS, nil
|
|
}
|
|
c.hostname = hello.ServerName
|
|
if c.hostname == "" {
|
|
c.hostname, _, _ = net.SplitHostPort(c.Conn.LocalAddr().String())
|
|
}
|
|
return isTLS, nil
|
|
}
|
|
|
|
func (c *Conn) init() {
|
|
c.once.Do(func() {
|
|
if c.isOriginal {
|
|
return
|
|
}
|
|
if c.tlsCfg != nil {
|
|
isTLS, err := c.isTLSConnection()
|
|
if err != nil {
|
|
c.initErr = err
|
|
return
|
|
}
|
|
c.isTLS = isTLS
|
|
}
|
|
|
|
if c.isTLS {
|
|
var cfg = c.tlsCfg
|
|
if c.getConfigForClient != nil {
|
|
cfg = c.getConfigForClient(c.hostname)
|
|
if cfg == nil {
|
|
cfg = c.tlsCfg
|
|
}
|
|
}
|
|
c.tlsConn = tls.Server(&myConn{
|
|
reader: c.buffer,
|
|
conn: c.Conn,
|
|
isReadOnly: false,
|
|
}, cfg)
|
|
} else {
|
|
if !c.allowNonTls {
|
|
c.initErr = net.ErrClosed
|
|
return
|
|
}
|
|
c.noTlsReader = io.MultiReader(c.buffer, c.Conn)
|
|
}
|
|
})
|
|
}
|
|
|
|
func (c *Conn) Read(b []byte) (int, error) {
|
|
c.init()
|
|
if c.initErr != nil {
|
|
return 0, c.initErr
|
|
}
|
|
if c.isTLS {
|
|
return c.tlsConn.Read(b)
|
|
}
|
|
return c.noTlsReader.Read(b)
|
|
}
|
|
|
|
func (c *Conn) Write(b []byte) (int, error) {
|
|
c.init()
|
|
if c.initErr != nil {
|
|
return 0, c.initErr
|
|
}
|
|
|
|
if c.isTLS {
|
|
return c.tlsConn.Write(b)
|
|
}
|
|
return c.Conn.Write(b)
|
|
}
|
|
|
|
func (c *Conn) Close() error {
|
|
if c.isTLS && c.tlsConn != nil {
|
|
return c.tlsConn.Close()
|
|
}
|
|
return c.Conn.Close()
|
|
}
|
|
|
|
func (c *Conn) SetDeadline(t time.Time) error {
|
|
if c.isTLS && c.tlsConn != nil {
|
|
return c.tlsConn.SetDeadline(t)
|
|
}
|
|
return c.Conn.SetDeadline(t)
|
|
}
|
|
|
|
func (c *Conn) SetReadDeadline(t time.Time) error {
|
|
if c.isTLS && c.tlsConn != nil {
|
|
return c.tlsConn.SetReadDeadline(t)
|
|
}
|
|
return c.Conn.SetReadDeadline(t)
|
|
}
|
|
|
|
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
|
if c.isTLS && c.tlsConn != nil {
|
|
return c.tlsConn.SetWriteDeadline(t)
|
|
}
|
|
return c.Conn.SetWriteDeadline(t)
|
|
}
|
|
|
|
func (c *Conn) TlsConnection() (*tls.Conn, error) {
|
|
if c.initErr != nil {
|
|
return nil, c.initErr
|
|
}
|
|
if !c.isTLS {
|
|
return nil, net.ErrClosed
|
|
}
|
|
return c.tlsConn, nil
|
|
}
|
|
|
|
func (c *Conn) OriginalConn() net.Conn {
|
|
return c.Conn
|
|
}
|
|
|
|
func NewClientTlsConn(conn net.Conn, cfg *tls.Config) (*Conn, error) {
|
|
if conn == nil {
|
|
return nil, net.ErrClosed
|
|
}
|
|
c := &Conn{
|
|
Conn: conn,
|
|
isTLS: true,
|
|
tlsCfg: cfg,
|
|
tlsConn: tls.Client(conn, cfg),
|
|
isOriginal: true,
|
|
}
|
|
return c, nil
|
|
}
|
|
|
|
func NewServerTlsConn(conn net.Conn, cfg *tls.Config) (*Conn, error) {
|
|
if conn == nil {
|
|
return nil, net.ErrClosed
|
|
}
|
|
c := &Conn{
|
|
Conn: conn,
|
|
isTLS: true,
|
|
tlsCfg: cfg,
|
|
tlsConn: tls.Server(conn, cfg),
|
|
isOriginal: true,
|
|
}
|
|
c.init()
|
|
return c, nil
|
|
}
|
|
|
|
func Dial(network, address string) (*Conn, error) {
|
|
conn, err := net.Dial(network, address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &Conn{
|
|
Conn: conn,
|
|
isTLS: false,
|
|
tlsCfg: nil,
|
|
tlsConn: nil,
|
|
noTlsReader: conn,
|
|
isOriginal: true,
|
|
}, nil
|
|
}
|
|
|
|
func DialTLS(network, address string, certFile, keyFile string) (*Conn, error) {
|
|
config, err := tls.LoadX509KeyPair(certFile, keyFile)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
tlsConfig := &tls.Config{
|
|
Certificates: []tls.Certificate{config},
|
|
}
|
|
|
|
conn, err := net.Dial(network, address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return NewClientTlsConn(conn, tlsConfig)
|
|
}
|