nat function improve

master
兔子 10 months ago
parent d2feccf3b3
commit 1276d3b6dd

@ -26,7 +26,7 @@ func (h *ReverseConfig) Run() error {
} }
for key, proxy := range h.proxy { for key, proxy := range h.proxy {
h.httpmux.HandleFunc(key, func(writer http.ResponseWriter, request *http.Request) { h.httpmux.HandleFunc(key, func(writer http.ResponseWriter, request *http.Request) {
starlog.Infof("<%s> Req Path:%s Addr:%s UA:%s\n", h.Name, request.URL.Path, request.RemoteAddr, request.Header.Get("User-Agent")) starlog.Infof("<%s> Req Path:%s ListenAddr:%s UA:%s\n", h.Name, request.URL.Path, request.RemoteAddr, request.Header.Get("User-Agent"))
if !h.BasicAuth(writer, request) { if !h.BasicAuth(writer, request) {
h.SetResponseHeader(writer) h.SetResponseHeader(writer)

@ -0,0 +1,23 @@
package net
import (
"testing"
"time"
)
func TestNat(t *testing.T) {
var s = NatServer{
ListenAddr: "0.0.0.0:10020",
enableTCP: true,
}
var c = NatClient{
ServiceTarget: "139.199.163.65:80",
CmdTarget: "127.0.0.1:10020",
enableTCP: true,
}
go s.Run()
go c.Run()
for {
time.Sleep(time.Second * 20)
}
}

@ -1,27 +1,138 @@
package net package net
import ( import (
"b612.me/starlog"
"bytes"
"context"
"crypto/sha256"
"io"
"net" "net"
"sync" "sync"
"time"
) )
type SimpleNatClient struct { type NatClient struct {
mu sync.RWMutex mu sync.RWMutex
cmdTCPConn net.Conn cmdTCPConn net.Conn
cmdUDPConn *net.UDPAddr cmdUDPConn *net.UDPAddr
ServiceTarget string ServiceTarget string
CmdTarget string CmdTarget string
tcpAlived bool tcpAlived bool
DialTimeout int
enableTCP bool
enableUDP bool
Passwd string
stopCtx context.Context
stopFn context.CancelFunc
} }
func (s *SimpleNatClient) tcpCmdConn() net.Conn { func (s *NatClient) tcpCmdConn() net.Conn {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
return s.cmdTCPConn return s.cmdTCPConn
} }
func (s *SimpleNatClient) tcpCmdConnAlived() bool { func (s *NatClient) tcpCmdConnAlived() bool {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
return s.tcpAlived return s.tcpAlived
} }
func (s *NatClient) setTcpCmdConnAlived(v bool) {
s.mu.Lock()
defer s.mu.Unlock()
s.tcpAlived = v
}
func (s *NatClient) Run() {
s.stopCtx, s.stopFn = context.WithCancel(context.Background())
if s.DialTimeout == 0 {
s.DialTimeout = 10000
}
if s.Passwd != "" {
MSG_CMD_HELLO = sha256.New().Sum(append(MSG_CMD_HELLO, []byte(s.Passwd)...))[:16]
}
if s.enableTCP {
s.runTcp()
}
}
func (s *NatClient) runTcp() error {
var err error
starlog.Noticeln("nat client tcp module start run")
for {
select {
case <-s.stopCtx.Done():
if s.cmdTCPConn != nil {
s.setTcpCmdConnAlived(false)
s.cmdTCPConn.Close()
return nil
}
case <-time.After(time.Millisecond * 1500):
}
if s.cmdTCPConn != nil && s.tcpCmdConnAlived() {
continue
}
s.cmdTCPConn, err = net.DialTimeout("tcp", s.CmdTarget, time.Millisecond*time.Duration(s.DialTimeout))
if err != nil {
starlog.Errorf("dail remote tcp cmd server %v fail:%v;will retry\n", s.CmdTarget, err)
time.Sleep(time.Second * 2)
s.cmdTCPConn = nil
continue
}
starlog.Infoln("dail remote tcp cmd server ok,remote:", s.CmdTarget)
s.tcpCmdConn().Write(MSG_CMD_HELLO)
s.setTcpCmdConnAlived(true)
go s.handleTcpCmdConn(s.tcpCmdConn())
}
}
func (s *NatClient) handleTcpCmdConn(conn net.Conn) {
for {
header := make([]byte, 16)
_, err := io.ReadFull(conn, header)
if err != nil {
starlog.Infoln("tcp cmd server read fail:", err)
conn.Close()
s.setTcpCmdConnAlived(false)
return
}
if bytes.Equal(header, MSG_CMD_HELLO_REPLY) {
continue
}
if bytes.Equal(header, MSG_NEW_CONN_HELLO) {
go s.newRemoteTcpConn()
}
if bytes.Equal(header, MSG_HEARTBEAT) {
_, err = conn.Write(MSG_HEARTBEAT)
if err != nil {
conn.Close()
s.setTcpCmdConnAlived(false)
return
}
}
}
}
func (s *NatClient) newRemoteTcpConn() {
log := starlog.Std.NewFlag()
starlog.Infoln("recv request,create new tcp conn")
nconn, err := net.DialTimeout("tcp", s.CmdTarget, time.Millisecond*time.Duration(s.DialTimeout))
if err != nil {
log.Errorf("dail server tcp conn %v fail:%v\n", s.CmdTarget, err)
return
}
_, err = nconn.Write(MSG_NEW_CONN_HELLO)
if err != nil {
nconn.Close()
log.Errorf("write new client hello to server %v fail:%v\n", s.CmdTarget, err)
return
}
cconn, err := net.DialTimeout("tcp", s.ServiceTarget, time.Millisecond*time.Duration(s.DialTimeout))
if err != nil {
log.Errorf("dail remote tcp conn %v fail:%v\n", s.CmdTarget, err)
return
}
go io.Copy(cconn, nconn)
go io.Copy(nconn, cconn)
}

@ -1,40 +1,46 @@
package net package net
import ( import (
"b612.me/starlog"
"bytes" "bytes"
"context" "context"
"crypto/sha256"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"io" "io"
"net" "net"
"strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
) )
// MSG_CMD_HELLO 控制链路主动链接参头 16byte // MSG_CMD_HELLO 控制链路主动链接参头 16byte
var MSG_CMD_HELLO, _ = hex.DecodeString("B6121127AF7ECDA1") var MSG_CMD_HELLO, _ = hex.DecodeString("B6121127AF7ECDA11965122519670220")
var MSG_CMD_HELLO_REPLY, _ = hex.DecodeString("B6121127AF7ECDA2") var MSG_CMD_HELLO_REPLY, _ = hex.DecodeString("B6121127AF7ECDA22002200820112014")
// MSG_NEW_CONN_HELLO 交链路主动连接头 16byte // MSG_NEW_CONN_HELLO 交链路主动连接头 16byte
var MSG_NEW_CONN_HELLO, _ = hex.DecodeString("B6121127AF7ECDFF") var MSG_NEW_CONN_HELLO, _ = hex.DecodeString("B6121127AF7ECDFF201820202022B612")
var MSG_HEARTBEAT, _ = hex.DecodeString("B612112704011008B612112704011008")
type NatServer struct { type NatServer struct {
sync.RWMutex sync.RWMutex
cmdTCPConn net.Conn cmdTCPConn net.Conn
listenTcp net.Listener listenTcp net.Listener
listenUDP *net.UDPConn listenUDP *net.UDPConn
Addr string ListenAddr string
Port int
lastTCPHeart int64 lastTCPHeart int64
lastUDPHeart int64 lastUDPHeart int64
Passwd string Passwd string
DialTimeout int64 NetTimeout int64
UDPTimeout int64 UDPTimeout int64
running int32 running int32
tcpConnPool chan net.Conn tcpConnPool chan net.Conn
stopCtx context.Context stopCtx context.Context
stopFn context.CancelFunc stopFn context.CancelFunc
enableTCP bool
enableUDP bool
} }
func (n *NatServer) Run() error { func (n *NatServer) Run() error {
@ -42,48 +48,164 @@ func (n *NatServer) Run() error {
return fmt.Errorf("Server Already Run") return fmt.Errorf("Server Already Run")
} }
n.stopCtx, n.stopFn = context.WithCancel(context.Background()) n.stopCtx, n.stopFn = context.WithCancel(context.Background())
return nil if n.NetTimeout == 0 {
n.NetTimeout = 10000
} }
if n.Passwd != "" {
func (n *NatServer) cmdTcploop(conn net.Conn) error { MSG_CMD_HELLO = sha256.New().Sum(append(MSG_CMD_HELLO, []byte(n.Passwd)...))[:16]
var header = make([]byte, 16)
for {
c, err := conn.Read(header)
if err != nil {
//todo
} }
if c != 16 {
if n.enableTCP {
go n.runTcpListen()
} }
} return nil
} }
func (n *NatServer) runTcpListen() error { func (n *NatServer) runTcpListen() error {
var err error
n.tcpConnPool = make(chan net.Conn, 128)
atomic.AddInt32(&n.running, 1) atomic.AddInt32(&n.running, 1)
defer atomic.AddInt32(&n.running, -1) defer atomic.AddInt32(&n.running, -1)
listener, err := net.Listen("tcp", n.Addr) starlog.Infoln("nat server tcp listener start run")
n.listenTcp, err = net.Listen("tcp", n.ListenAddr)
if err != nil { if err != nil {
starlog.Errorln("nat server tcp listener start failed:", err)
return err return err
} }
n.listenTcp = listener msgChan := make(chan []byte, 16)
for { for {
conn, err := listener.Accept() conn, err := n.listenTcp.Accept()
if err != nil { if err != nil {
continue continue
} }
headedr := make([]byte, 16) var ok bool
conn.SetReadDeadline(time.Now().Add(time.Millisecond * 700)) if n.cmdTCPConn == nil {
c, err := conn.Read(headedr) if conn, ok = n.checkIsTcpControlConn(conn); ok {
if err == nil && c == 16 {
if bytes.Equal(headedr, MSG_CMD_HELLO) {
if n.cmdTCPConn != nil {
n.cmdTCPConn.Close()
}
n.cmdTCPConn = conn n.cmdTCPConn = conn
conn.Write(MSG_CMD_HELLO_REPLY) conn.Write(MSG_CMD_HELLO_REPLY)
// go n.handleTcpControlConn(conn, msgChan)
continue
}
}
if conn, ok = n.checkIsTcpNewConn(conn); ok {
starlog.Noticef("new tcp cmd conn is client conn %v\n", conn.RemoteAddr().String())
n.tcpConnPool <- conn
continue
}
starlog.Noticef("new tcp cmd conn is not client conn %v\n", conn.RemoteAddr().String())
go func() {
msgChan <- MSG_NEW_CONN_HELLO
}()
go n.pairNewClientConn(conn)
}
}
func (n *NatServer) pairNewClientConn(conn net.Conn) {
log := starlog.Std.NewFlag()
log.Noticef("start pair tcp cmd conn %v\n", conn.RemoteAddr().String())
select {
case <-time.After(time.Millisecond * time.Duration(n.NetTimeout)):
log.Errorln("pair new conn fail,wait timeout,conn is:", conn)
conn.Close()
return
case nconn := <-n.tcpConnPool:
log.Infof("pair %v <======> %v ok\n", conn.RemoteAddr().String(), nconn.RemoteAddr().String())
go io.Copy(nconn, conn)
go io.Copy(conn, nconn)
return
}
}
func (n *NatServer) handleTcpControlConn(conn net.Conn, msg chan []byte) {
go func() {
for {
select {
case data := <-msg:
_, err := conn.Write(data)
if err != nil {
conn.Close()
n.cmdTCPConn = nil
return
}
case <-time.After(time.Minute):
_, err := conn.Write(MSG_HEARTBEAT)
if err != nil {
conn.Close()
n.cmdTCPConn = nil
return
}
}
}
}()
for {
header := make([]byte, 16)
_, err := io.ReadFull(conn, header)
if err != nil {
conn.Close()
n.cmdTCPConn = nil
return
}
if bytes.Equal(header, MSG_HEARTBEAT) {
n.lastTCPHeart = time.Now().Unix()
}
continue
}
}
func (n *NatServer) checkIsTcpControlConn(conn net.Conn) (net.Conn, bool) {
log := starlog.Std.NewFlag()
log.Noticef("start check tcp cmd conn %v\n", conn.RemoteAddr().String())
header := make([]byte, 16)
conn.SetReadDeadline(time.Now().Add(time.Millisecond * 1200))
count, err := io.ReadFull(conn, header)
conn.SetReadDeadline(time.Time{})
if err == nil {
if bytes.Equal(header, MSG_CMD_HELLO) {
log.Infof("check tcp cmd conn success:%v\n", conn.RemoteAddr().String())
return conn, true
}
}
log.Infof("check tcp cmd conn fail:%v %v\n", conn.RemoteAddr().String(), err)
return NewCensorConn(header[:count], conn), false
}
func (n *NatServer) checkIsTcpNewConn(conn net.Conn) (net.Conn, bool) {
if n.cmdTCPConn == nil {
return conn, false
}
remoteIp := strings.Split(n.cmdTCPConn.RemoteAddr().String(), ":")[0]
newConnIp := strings.Split(conn.RemoteAddr().String(), ":")[0]
if remoteIp != newConnIp {
return conn, false
} }
header := make([]byte, 16)
conn.SetReadDeadline(time.Now().Add(time.Millisecond * 1200))
read, err := io.ReadFull(conn, header)
conn.SetReadDeadline(time.Time{})
if err == nil {
if bytes.Equal(header, MSG_NEW_CONN_HELLO) {
return conn, true
} }
io.ReadFull(conn, headedr) }
return NewCensorConn(header[:read], conn), false
}
type censorConn struct {
reader io.Reader
conn net.Conn
}
func NewCensorConn(header []byte, conn net.Conn) censorConn {
return censorConn{
reader: io.MultiReader(bytes.NewReader(header), conn),
conn: conn,
} }
} }
func (c censorConn) Read(p []byte) (int, error) { return c.reader.Read(p) }
func (c censorConn) Write(p []byte) (int, error) { return c.conn.Write(p) }
func (c censorConn) Close() error { return c.conn.Close() }
func (c censorConn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
func (c censorConn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() }
func (c censorConn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }
func (c censorConn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }
func (c censorConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }

Loading…
Cancel
Save