master
兔子 8 months ago
parent 1276d3b6dd
commit 4074adfcd9

@ -13,3 +13,46 @@ var Cmd = &cobra.Command{
func init() { func init() {
Cmd.AddCommand(netforward.CmdNetforward) Cmd.AddCommand(netforward.CmdNetforward)
} }
var natc NatClient
var nats NatServer
func init() {
CmdNatClient.Flags().StringVarP(&natc.ServiceTarget, "target", "t", "", "forward server target address")
CmdNatClient.Flags().StringVarP(&natc.CmdTarget, "server", "s", "", "nat server command address")
CmdNatClient.Flags().StringVarP(&natc.Passwd, "passwd", "p", "", "password")
CmdNatClient.Flags().BoolVarP(&natc.enableTCP, "enable-tcp", "T", true, "enable tcp forward")
CmdNatClient.Flags().BoolVarP(&natc.enableUDP, "enable-udp", "U", true, "enable udp forward")
CmdNatClient.Flags().IntVarP(&natc.DialTimeout, "dial-timeout", "d", 10000, "dial timeout milliseconds")
CmdNatClient.Flags().IntVarP(&natc.UdpTimeout, "udp-timeout", "D", 60000, "udp connection timeout milliseconds")
Cmd.AddCommand(CmdNatClient)
CmdNatServer.Flags().StringVarP(&nats.ListenAddr, "listen", "l", "", "listen address")
CmdNatServer.Flags().StringVarP(&nats.Passwd, "passwd", "p", "", "password")
CmdNatServer.Flags().Int64VarP(&nats.UDPTimeout, "udp-timeout", "D", 60000, "udp connection timeout milliseconds")
CmdNatServer.Flags().Int64VarP(&nats.NetTimeout, "dial-timeout", "d", 10000, "dial timeout milliseconds")
CmdNatServer.Flags().BoolVarP(&nats.enableTCP, "enable-tcp", "T", true, "enable tcp forward")
CmdNatServer.Flags().BoolVarP(&nats.enableUDP, "enable-udp", "U", true, "enable udp forward")
Cmd.AddCommand(CmdNatServer)
}
var CmdNatClient = &cobra.Command{
Use: "natc",
Short: "nat client",
Run: func(cmd *cobra.Command, args []string) {
if natc.ServiceTarget == "" || natc.CmdTarget == "" {
cmd.Help()
return
}
natc.Run()
},
}
var CmdNatServer = &cobra.Command{
Use: "nats",
Short: "nat server",
Run: func(cmd *cobra.Command, args []string) {
nats.Run()
},
}

@ -9,11 +9,13 @@ func TestNat(t *testing.T) {
var s = NatServer{ var s = NatServer{
ListenAddr: "0.0.0.0:10020", ListenAddr: "0.0.0.0:10020",
enableTCP: true, enableTCP: true,
enableUDP: true,
} }
var c = NatClient{ var c = NatClient{
ServiceTarget: "139.199.163.65:80", ServiceTarget: "dns.b612.me:521",
CmdTarget: "127.0.0.1:10020", CmdTarget: "127.0.0.1:10020",
enableTCP: true, enableTCP: true,
enableUDP: true,
} }
go s.Run() go s.Run()
go c.Run() go c.Run()

@ -14,14 +14,16 @@ import (
type NatClient struct { type NatClient struct {
mu sync.RWMutex mu sync.RWMutex
cmdTCPConn net.Conn cmdTCPConn net.Conn
cmdUDPConn *net.UDPAddr cmdUDPConn *net.UDPConn
ServiceTarget string ServiceTarget string
CmdTarget string CmdTarget string
tcpAlived bool tcpAlived bool
DialTimeout int DialTimeout int
UdpTimeout int
enableTCP bool enableTCP bool
enableUDP bool enableUDP bool
Passwd string Passwd string
udpAlived bool
stopCtx context.Context stopCtx context.Context
stopFn context.CancelFunc stopFn context.CancelFunc
} }
@ -32,6 +34,12 @@ func (s *NatClient) tcpCmdConn() net.Conn {
return s.cmdTCPConn return s.cmdTCPConn
} }
func (s *NatClient) udpCmdConn() *net.UDPConn {
s.mu.RLock()
defer s.mu.RUnlock()
return s.cmdUDPConn
}
func (s *NatClient) tcpCmdConnAlived() bool { func (s *NatClient) tcpCmdConnAlived() bool {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
@ -44,7 +52,19 @@ func (s *NatClient) setTcpCmdConnAlived(v bool) {
s.tcpAlived = v s.tcpAlived = v
} }
func (s *NatClient) Run() { func (s *NatClient) udpCmdConnAlived() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.udpAlived
}
func (s *NatClient) setUdpCmdConnAlived(v bool) {
s.mu.Lock()
defer s.mu.Unlock()
s.udpAlived = v
}
func (s *NatClient) Run() error {
s.stopCtx, s.stopFn = context.WithCancel(context.Background()) s.stopCtx, s.stopFn = context.WithCancel(context.Background())
if s.DialTimeout == 0 { if s.DialTimeout == 0 {
s.DialTimeout = 10000 s.DialTimeout = 10000
@ -52,9 +72,23 @@ func (s *NatClient) Run() {
if s.Passwd != "" { if s.Passwd != "" {
MSG_CMD_HELLO = sha256.New().Sum(append(MSG_CMD_HELLO, []byte(s.Passwd)...))[:16] MSG_CMD_HELLO = sha256.New().Sum(append(MSG_CMD_HELLO, []byte(s.Passwd)...))[:16]
} }
var wg sync.WaitGroup
if s.enableUDP {
wg.Add(1)
go func() {
defer wg.Done()
s.runUdp()
}()
}
if s.enableTCP { if s.enableTCP {
wg.Add(1)
go func() {
defer wg.Done()
s.runTcp() s.runTcp()
}()
} }
wg.Wait()
return nil
} }
func (s *NatClient) runTcp() error { func (s *NatClient) runTcp() error {
@ -87,6 +121,70 @@ func (s *NatClient) runTcp() error {
} }
} }
func (s *NatClient) runUdp() error {
starlog.Noticeln("nat client udp module start run")
if s.UdpTimeout == 0 {
s.UdpTimeout = 600000
}
for {
select {
case <-s.stopCtx.Done():
if s.cmdTCPConn != nil {
s.setUdpCmdConnAlived(false)
s.cmdUDPConn.Close()
return nil
}
case <-time.After(time.Millisecond * 3000):
}
if s.cmdUDPConn != nil && s.udpCmdConnAlived() {
continue
}
rmt, err := net.ResolveUDPAddr("udp", s.CmdTarget)
if err != nil {
starlog.Errorf("dail remote udp cmd server %v fail:%v;will retry\n", s.CmdTarget, err)
time.Sleep(time.Second * 2)
continue
}
s.cmdUDPConn, err = net.DialUDP("udp", nil, rmt)
if err != nil {
starlog.Errorf("dail remote udp cmd server %v fail:%v;will retry\n", s.CmdTarget, err)
time.Sleep(time.Second * 2)
s.cmdTCPConn = nil
continue
}
starlog.Infoln("dail remote udp cmd server ok,remote:", s.CmdTarget)
s.udpCmdConn().Write(MSG_CMD_HELLO)
s.setUdpCmdConnAlived(true)
go s.handleUdpCmdConn(s.udpCmdConn())
}
}
func (s *NatClient) handleUdpCmdConn(conn *net.UDPConn) {
for {
header := make([]byte, 16)
_, err := io.ReadFull(conn, header)
if err != nil {
starlog.Infoln("udp cmd server read fail:", err)
conn.Close()
s.setUdpCmdConnAlived(false)
return
}
if bytes.Equal(header, MSG_CMD_HELLO_REPLY) {
continue
}
if bytes.Equal(header, MSG_NEW_CONN_HELLO) {
go s.newRemoteUdpConn()
}
if bytes.Equal(header, MSG_HEARTBEAT) {
_, err = conn.Write(MSG_HEARTBEAT)
if err != nil {
conn.Close()
s.setUdpCmdConnAlived(false)
return
}
}
}
}
func (s *NatClient) handleTcpCmdConn(conn net.Conn) { func (s *NatClient) handleTcpCmdConn(conn net.Conn) {
for { for {
header := make([]byte, 16) header := make([]byte, 16)
@ -125,14 +223,121 @@ func (s *NatClient) newRemoteTcpConn() {
_, err = nconn.Write(MSG_NEW_CONN_HELLO) _, err = nconn.Write(MSG_NEW_CONN_HELLO)
if err != nil { if err != nil {
nconn.Close() nconn.Close()
log.Errorf("write new client hello to server %v fail:%v\n", s.CmdTarget, err) log.Errorf("write new tcp client hello to server %v fail:%v\n", s.CmdTarget, err)
return return
} }
cconn, err := net.DialTimeout("tcp", s.ServiceTarget, time.Millisecond*time.Duration(s.DialTimeout)) cconn, err := net.DialTimeout("tcp", s.ServiceTarget, time.Millisecond*time.Duration(s.DialTimeout))
if err != nil { if err != nil {
log.Errorf("dail remote tcp conn %v fail:%v\n", s.CmdTarget, err) log.Errorf("dail remote tcp conn %v fail:%v\n", s.CmdTarget, err)
nconn.Close()
return return
} }
go io.Copy(cconn, nconn) go func() {
go io.Copy(nconn, cconn) for {
data := make([]byte, 8192)
nconn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(s.UdpTimeout)))
n, err := nconn.Read(data)
if err != nil {
starlog.Infoln("read from tcp server fail:", nconn.RemoteAddr(), err)
nconn.Close()
cconn.Close()
return
}
_, err = cconn.Write(data[:n])
//starlog.Debugln("write to udp client:", p, err, cconn.LocalAddr(), cconn.RemoteAddr())
if err != nil {
starlog.Infoln("write to tcp client fail:", cconn.RemoteAddr(), err)
nconn.Close()
cconn.Close()
return
}
}
}()
go func() {
for {
data := make([]byte, 8192)
cconn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(s.UdpTimeout)))
n, err := cconn.Read(data)
if err != nil {
starlog.Infoln("read from tcp server fail:", cconn.RemoteAddr(), err)
nconn.Close()
cconn.Close()
return
}
_, err = nconn.Write(data[:n])
if err != nil {
starlog.Infoln("write to tcp client fail:", nconn.RemoteAddr(), err)
nconn.Close()
cconn.Close()
return
}
}
}()
}
func (s *NatClient) newRemoteUdpConn() {
log := starlog.Std.NewFlag()
starlog.Infoln("recv request,create new udp conn")
rmt, err := net.ResolveUDPAddr("udp", s.CmdTarget)
if err != nil {
log.Errorf("dail server udp conn %v fail:%v\n", s.CmdTarget, err)
return
}
nconn, err := net.DialUDP("udp", nil, rmt)
if err != nil {
log.Errorf("dail server udp conn %v fail:%v\n", s.CmdTarget, err)
return
}
log.Infof("dail server udp conn %v ok\n", s.CmdTarget)
_, err = nconn.Write(MSG_NEW_CONN_HELLO)
if err != nil {
nconn.Close()
log.Errorf("write new udp client hello to server %v fail:%v\n", s.CmdTarget, err)
return
}
rmt, err = net.ResolveUDPAddr("udp", s.ServiceTarget)
if err != nil {
log.Errorf("dail server udp conn %v fail:%v\n", s.ServiceTarget, err)
return
}
cconn, err := net.DialUDP("udp", nil, rmt)
if err != nil {
log.Errorf("dail remote udp conn %v fail:%v\n", s.ServiceTarget, err)
return
}
log.Infof("dail remote udp conn %v ok\n", s.ServiceTarget)
go func() {
for {
data := make([]byte, 8192)
nconn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(s.UdpTimeout)))
n, err := nconn.Read(data)
if err != nil {
starlog.Infoln("read from udp server fail:", err)
return
}
_, err = cconn.Write(data[:n])
//starlog.Debugln("write to udp client:", p, err, cconn.LocalAddr(), cconn.RemoteAddr())
if err != nil {
starlog.Infoln("write to udp client fail:", err)
return
}
}
}()
go func() {
for {
data := make([]byte, 8192)
cconn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(s.UdpTimeout)))
n, err := cconn.Read(data)
if err != nil {
starlog.Infoln("read from udp server fail:", err)
return
}
_, err = nconn.Write(data[:n])
if err != nil {
starlog.Infoln("write to udp client fail:", err)
return
}
}
}()
} }

@ -22,6 +22,7 @@ var MSG_CMD_HELLO_REPLY, _ = hex.DecodeString("B6121127AF7ECDA22002200820112014"
// MSG_NEW_CONN_HELLO 交链路主动连接头 16byte // MSG_NEW_CONN_HELLO 交链路主动连接头 16byte
var MSG_NEW_CONN_HELLO, _ = hex.DecodeString("B6121127AF7ECDFF201820202022B612") var MSG_NEW_CONN_HELLO, _ = hex.DecodeString("B6121127AF7ECDFF201820202022B612")
// MSG_HEARTBEAT 心跳报文 16byte
var MSG_HEARTBEAT, _ = hex.DecodeString("B612112704011008B612112704011008") var MSG_HEARTBEAT, _ = hex.DecodeString("B612112704011008B612112704011008")
type NatServer struct { type NatServer struct {
@ -29,6 +30,9 @@ type NatServer struct {
cmdTCPConn net.Conn cmdTCPConn net.Conn
listenTcp net.Listener listenTcp net.Listener
listenUDP *net.UDPConn listenUDP *net.UDPConn
udpConnMap sync.Map
udpPairMap sync.Map
udpCmdAddr *net.UDPAddr
ListenAddr string ListenAddr string
lastTCPHeart int64 lastTCPHeart int64
lastUDPHeart int64 lastUDPHeart int64
@ -37,6 +41,7 @@ type NatServer struct {
UDPTimeout int64 UDPTimeout int64
running int32 running int32
tcpConnPool chan net.Conn tcpConnPool chan net.Conn
udpConnPool chan addionData
stopCtx context.Context stopCtx context.Context
stopFn context.CancelFunc stopFn context.CancelFunc
enableTCP bool enableTCP bool
@ -54,10 +59,22 @@ func (n *NatServer) Run() error {
if n.Passwd != "" { if n.Passwd != "" {
MSG_CMD_HELLO = sha256.New().Sum(append(MSG_CMD_HELLO, []byte(n.Passwd)...))[:16] MSG_CMD_HELLO = sha256.New().Sum(append(MSG_CMD_HELLO, []byte(n.Passwd)...))[:16]
} }
var wg sync.WaitGroup
if n.enableUDP {
wg.Add(1)
go func() {
defer wg.Done()
n.runUdpListen()
}()
}
if n.enableTCP { if n.enableTCP {
go n.runTcpListen() wg.Add(1)
go func() {
defer wg.Done()
n.runTcpListen()
}()
} }
wg.Wait()
return nil return nil
} }
@ -100,6 +117,151 @@ func (n *NatServer) runTcpListen() error {
} }
} }
func (n *NatServer) runUdpListen() error {
var err error
atomic.AddInt32(&n.running, 1)
defer atomic.AddInt32(&n.running, -1)
starlog.Infoln("nat server udp listener start run")
if n.UDPTimeout == 0 {
n.UDPTimeout = 120
}
n.udpConnPool = make(chan addionData, 128)
udpListenAddr, err := net.ResolveUDPAddr("udp", n.ListenAddr)
if err != nil {
starlog.Errorln("nat server udp listener start failed:", err)
return err
}
n.listenUDP, err = net.ListenUDP("udp", udpListenAddr)
if err != nil {
starlog.Errorln("nat server tcp listener start failed:", err)
return err
}
go func() {
for {
select {
case <-n.stopCtx.Done():
if n.listenUDP != nil {
n.listenUDP.Close()
}
case <-time.After(time.Second * 30):
if time.Now().Unix()-n.lastUDPHeart > n.UDPTimeout {
if n.udpCmdAddr != nil {
n.udpCmdAddr = nil
}
}
if n.udpCmdAddr != nil {
n.listenUDP.WriteToUDP(MSG_HEARTBEAT, n.udpCmdAddr)
}
n.udpConnMap.Range(func(key, value interface{}) bool {
if time.Now().Unix()-value.(addionData).lastHeartbeat > n.UDPTimeout {
if taregt, ok := n.udpPairMap.Load(key); ok {
n.udpConnMap.Delete(taregt)
n.udpPairMap.Delete(taregt)
}
n.udpConnMap.Delete(key)
n.udpPairMap.Delete(key)
}
return true
})
}
}
}()
for {
data := make([]byte, 8192)
c, udpAddr, err := n.listenUDP.ReadFromUDP(data)
if err != nil {
continue
}
n.handleUdpData(udpAddr, data[:c])
}
}
type addionData struct {
lastHeartbeat int64
Addr *net.UDPAddr
MsgFrom []byte
}
func (n *NatServer) handleUdpData(addr *net.UDPAddr, data []byte) {
starlog.Infoln("handle udp data from:", addr.String())
if addr.String() == n.udpCmdAddr.String() && len(data) >= 16 {
if bytes.Equal(data[:16], MSG_HEARTBEAT) {
starlog.Infoln("recv udp cmd heartbeat")
n.lastUDPHeart = time.Now().Unix()
}
return
}
if n.udpCmdAddr == nil {
if len(data) >= 16 && bytes.Equal(data[:16], MSG_CMD_HELLO) {
starlog.Infof("recv udp cmd hello from %v\n", addr.String())
n.udpCmdAddr = addr
n.lastUDPHeart = time.Now().Unix()
n.listenUDP.WriteToUDP(MSG_CMD_HELLO_REPLY, addr)
return
}
}
if _, ok := n.udpConnMap.Load(addr.IP.String()); ok {
if target, ok := n.udpPairMap.Load(addr.IP.String()); ok {
starlog.Infof("found udp pair data %v <=====> %v\n", addr.String(), target.(*net.UDPAddr).String())
rmt := target.(*net.UDPAddr)
if _, ok := n.udpConnMap.Load(rmt.IP.String()); !ok {
n.udpConnMap.Delete(addr.IP.String())
n.udpPairMap.Delete(addr.IP.String())
n.udpPairMap.Delete(rmt.IP.String())
starlog.Errorf("udp pair data %v <=====> %v fail,remote not found\n", addr.String(), rmt.String())
return
}
tmp, _ := n.udpConnMap.Load(addr.IP.String())
current := tmp.(addionData)
current.lastHeartbeat = time.Now().Unix()
n.udpConnMap.Store(addr.IP.String(), current)
return
}
}
if len(data) >= 16 {
if bytes.Equal(data[:16], MSG_NEW_CONN_HELLO) {
starlog.Infof("recv new udp conn hello from %v\n", addr.String())
if len(data) < 16 {
data = data[16:]
} else {
data = []byte{}
}
n.udpConnMap.Store(addr.IP.String(), addionData{
lastHeartbeat: time.Now().Unix(),
Addr: addr,
})
n.udpConnPool <- addionData{
lastHeartbeat: time.Now().Unix(),
Addr: addr,
MsgFrom: data,
}
return
}
}
starlog.Infof("wait pair udp conn %v\n", addr.String())
if n.udpCmdAddr == nil {
starlog.Infof("wait pair udp conn %v fail,cmd addr is nil\n", addr.String())
return
} else {
n.listenUDP.WriteToUDP(MSG_NEW_CONN_HELLO, n.udpCmdAddr)
}
go func() {
pairAddr := <-n.udpConnPool
n.udpConnMap.Store(addr.String(), addionData{
lastHeartbeat: time.Now().Unix(),
Addr: addr,
})
n.udpPairMap.Store(addr.IP.String(), pairAddr.Addr)
n.udpPairMap.Store(pairAddr.Addr.String(), addr.IP)
starlog.Infof("pair udp conn %v <=====> %v\n", addr.String(), pairAddr.Addr.String())
if len(pairAddr.MsgFrom) > 0 {
n.listenUDP.WriteToUDP(pairAddr.MsgFrom, addr)
}
n.listenUDP.WriteToUDP(data, pairAddr.Addr)
}()
}
func (n *NatServer) pairNewClientConn(conn net.Conn) { func (n *NatServer) pairNewClientConn(conn net.Conn) {
log := starlog.Std.NewFlag() log := starlog.Std.NewFlag()
log.Noticef("start pair tcp cmd conn %v\n", conn.RemoteAddr().String()) log.Noticef("start pair tcp cmd conn %v\n", conn.RemoteAddr().String())
@ -110,8 +272,16 @@ func (n *NatServer) pairNewClientConn(conn net.Conn) {
return return
case nconn := <-n.tcpConnPool: case nconn := <-n.tcpConnPool:
log.Infof("pair %v <======> %v ok\n", conn.RemoteAddr().String(), nconn.RemoteAddr().String()) log.Infof("pair %v <======> %v ok\n", conn.RemoteAddr().String(), nconn.RemoteAddr().String())
go io.Copy(nconn, conn) go func() {
go io.Copy(conn, nconn) defer nconn.Close()
defer conn.Close()
io.Copy(nconn, conn)
}()
go func() {
defer nconn.Close()
defer conn.Close()
io.Copy(conn, nconn)
}()
return return
} }
} }

Loading…
Cancel
Save