diff --git a/function.md b/function.md new file mode 100644 index 0000000..ba38a01 --- /dev/null +++ b/function.md @@ -0,0 +1 @@ +## diff --git a/net/cmd.go b/net/cmd.go index 24499d8..cd99152 100644 --- a/net/cmd.go +++ b/net/cmd.go @@ -13,3 +13,46 @@ var Cmd = &cobra.Command{ func init() { Cmd.AddCommand(netforward.CmdNetforward) } + +var natc NatClient +var nats NatServer + +func init() { + CmdNatClient.Flags().StringVarP(&natc.ServiceTarget, "target", "t", "", "forward server target address") + CmdNatClient.Flags().StringVarP(&natc.CmdTarget, "server", "s", "", "nat server command address") + CmdNatClient.Flags().StringVarP(&natc.Passwd, "passwd", "p", "", "password") + CmdNatClient.Flags().BoolVarP(&natc.enableTCP, "enable-tcp", "T", true, "enable tcp forward") + CmdNatClient.Flags().BoolVarP(&natc.enableUDP, "enable-udp", "U", true, "enable udp forward") + CmdNatClient.Flags().IntVarP(&natc.DialTimeout, "dial-timeout", "d", 10000, "dial timeout milliseconds") + CmdNatClient.Flags().IntVarP(&natc.UdpTimeout, "udp-timeout", "D", 60000, "udp connection timeout milliseconds") + Cmd.AddCommand(CmdNatClient) + + CmdNatServer.Flags().StringVarP(&nats.ListenAddr, "listen", "l", "", "listen address") + CmdNatServer.Flags().StringVarP(&nats.Passwd, "passwd", "p", "", "password") + CmdNatServer.Flags().Int64VarP(&nats.UDPTimeout, "udp-timeout", "D", 60000, "udp connection timeout milliseconds") + CmdNatServer.Flags().Int64VarP(&nats.NetTimeout, "dial-timeout", "d", 10000, "dial timeout milliseconds") + CmdNatServer.Flags().BoolVarP(&nats.enableTCP, "enable-tcp", "T", true, "enable tcp forward") + CmdNatServer.Flags().BoolVarP(&nats.enableUDP, "enable-udp", "U", true, "enable udp forward") + Cmd.AddCommand(CmdNatServer) +} + +var CmdNatClient = &cobra.Command{ + Use: "natc", + Short: "nat client", + Run: func(cmd *cobra.Command, args []string) { + if natc.ServiceTarget == "" || natc.CmdTarget == "" { + cmd.Help() + return + } + natc.Run() + }, +} + +var CmdNatServer = &cobra.Command{ + Use: "nats", + Short: "nat server", + Run: func(cmd *cobra.Command, args []string) { + + nats.Run() + }, +} diff --git a/net/nat_test.go b/net/nat_test.go index 06b1978..ca9411a 100644 --- a/net/nat_test.go +++ b/net/nat_test.go @@ -9,11 +9,13 @@ func TestNat(t *testing.T) { var s = NatServer{ ListenAddr: "0.0.0.0:10020", enableTCP: true, + enableUDP: true, } var c = NatClient{ - ServiceTarget: "139.199.163.65:80", + ServiceTarget: "dns.b612.me:521", CmdTarget: "127.0.0.1:10020", enableTCP: true, + enableUDP: true, } go s.Run() go c.Run() diff --git a/net/natclient.go b/net/natclient.go index e564159..cd1a9e4 100644 --- a/net/natclient.go +++ b/net/natclient.go @@ -14,14 +14,16 @@ import ( type NatClient struct { mu sync.RWMutex cmdTCPConn net.Conn - cmdUDPConn *net.UDPAddr + cmdUDPConn *net.UDPConn ServiceTarget string CmdTarget string tcpAlived bool DialTimeout int + UdpTimeout int enableTCP bool enableUDP bool Passwd string + udpAlived bool stopCtx context.Context stopFn context.CancelFunc } @@ -32,6 +34,12 @@ func (s *NatClient) tcpCmdConn() net.Conn { return s.cmdTCPConn } +func (s *NatClient) udpCmdConn() *net.UDPConn { + s.mu.RLock() + defer s.mu.RUnlock() + return s.cmdUDPConn +} + func (s *NatClient) tcpCmdConnAlived() bool { s.mu.RLock() defer s.mu.RUnlock() @@ -44,7 +52,19 @@ func (s *NatClient) setTcpCmdConnAlived(v bool) { s.tcpAlived = v } -func (s *NatClient) Run() { +func (s *NatClient) udpCmdConnAlived() bool { + s.mu.RLock() + defer s.mu.RUnlock() + return s.udpAlived +} + +func (s *NatClient) setUdpCmdConnAlived(v bool) { + s.mu.Lock() + defer s.mu.Unlock() + s.udpAlived = v +} + +func (s *NatClient) Run() error { s.stopCtx, s.stopFn = context.WithCancel(context.Background()) if s.DialTimeout == 0 { s.DialTimeout = 10000 @@ -52,9 +72,23 @@ func (s *NatClient) Run() { if s.Passwd != "" { MSG_CMD_HELLO = sha256.New().Sum(append(MSG_CMD_HELLO, []byte(s.Passwd)...))[:16] } + var wg sync.WaitGroup + if s.enableUDP { + wg.Add(1) + go func() { + defer wg.Done() + s.runUdp() + }() + } if s.enableTCP { - s.runTcp() + wg.Add(1) + go func() { + defer wg.Done() + s.runTcp() + }() } + wg.Wait() + return nil } func (s *NatClient) runTcp() error { @@ -87,6 +121,70 @@ func (s *NatClient) runTcp() error { } } +func (s *NatClient) runUdp() error { + starlog.Noticeln("nat client udp module start run") + if s.UdpTimeout == 0 { + s.UdpTimeout = 600000 + } + for { + select { + case <-s.stopCtx.Done(): + if s.cmdTCPConn != nil { + s.setUdpCmdConnAlived(false) + s.cmdUDPConn.Close() + return nil + } + case <-time.After(time.Millisecond * 3000): + } + if s.cmdUDPConn != nil && s.udpCmdConnAlived() { + continue + } + rmt, err := net.ResolveUDPAddr("udp", s.CmdTarget) + if err != nil { + starlog.Errorf("dail remote udp cmd server %v fail:%v;will retry\n", s.CmdTarget, err) + time.Sleep(time.Second * 2) + continue + } + s.cmdUDPConn, err = net.DialUDP("udp", nil, rmt) + if err != nil { + starlog.Errorf("dail remote udp cmd server %v fail:%v;will retry\n", s.CmdTarget, err) + time.Sleep(time.Second * 2) + s.cmdTCPConn = nil + continue + } + starlog.Infoln("dail remote udp cmd server ok,remote:", s.CmdTarget) + s.udpCmdConn().Write(MSG_CMD_HELLO) + s.setUdpCmdConnAlived(true) + go s.handleUdpCmdConn(s.udpCmdConn()) + } +} +func (s *NatClient) handleUdpCmdConn(conn *net.UDPConn) { + for { + header := make([]byte, 16) + _, err := io.ReadFull(conn, header) + if err != nil { + starlog.Infoln("udp cmd server read fail:", err) + conn.Close() + s.setUdpCmdConnAlived(false) + return + } + if bytes.Equal(header, MSG_CMD_HELLO_REPLY) { + continue + } + if bytes.Equal(header, MSG_NEW_CONN_HELLO) { + go s.newRemoteUdpConn() + } + if bytes.Equal(header, MSG_HEARTBEAT) { + _, err = conn.Write(MSG_HEARTBEAT) + if err != nil { + conn.Close() + s.setUdpCmdConnAlived(false) + return + } + } + } +} + func (s *NatClient) handleTcpCmdConn(conn net.Conn) { for { header := make([]byte, 16) @@ -125,14 +223,121 @@ func (s *NatClient) newRemoteTcpConn() { _, err = nconn.Write(MSG_NEW_CONN_HELLO) if err != nil { nconn.Close() - log.Errorf("write new client hello to server %v fail:%v\n", s.CmdTarget, err) + log.Errorf("write new tcp client hello to server %v fail:%v\n", s.CmdTarget, err) return } cconn, err := net.DialTimeout("tcp", s.ServiceTarget, time.Millisecond*time.Duration(s.DialTimeout)) if err != nil { log.Errorf("dail remote tcp conn %v fail:%v\n", s.CmdTarget, err) + nconn.Close() + return + } + go func() { + for { + data := make([]byte, 8192) + nconn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(s.UdpTimeout))) + n, err := nconn.Read(data) + if err != nil { + starlog.Infoln("read from tcp server fail:", nconn.RemoteAddr(), err) + nconn.Close() + cconn.Close() + return + } + _, err = cconn.Write(data[:n]) + //starlog.Debugln("write to udp client:", p, err, cconn.LocalAddr(), cconn.RemoteAddr()) + if err != nil { + starlog.Infoln("write to tcp client fail:", cconn.RemoteAddr(), err) + nconn.Close() + cconn.Close() + return + } + } + }() + go func() { + for { + data := make([]byte, 8192) + cconn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(s.UdpTimeout))) + n, err := cconn.Read(data) + if err != nil { + starlog.Infoln("read from tcp server fail:", cconn.RemoteAddr(), err) + nconn.Close() + cconn.Close() + return + } + _, err = nconn.Write(data[:n]) + if err != nil { + starlog.Infoln("write to tcp client fail:", nconn.RemoteAddr(), err) + nconn.Close() + cconn.Close() + return + } + } + }() +} + +func (s *NatClient) newRemoteUdpConn() { + log := starlog.Std.NewFlag() + starlog.Infoln("recv request,create new udp conn") + rmt, err := net.ResolveUDPAddr("udp", s.CmdTarget) + if err != nil { + log.Errorf("dail server udp conn %v fail:%v\n", s.CmdTarget, err) + return + } + nconn, err := net.DialUDP("udp", nil, rmt) + if err != nil { + log.Errorf("dail server udp conn %v fail:%v\n", s.CmdTarget, err) return } - go io.Copy(cconn, nconn) - go io.Copy(nconn, cconn) + log.Infof("dail server udp conn %v ok\n", s.CmdTarget) + _, err = nconn.Write(MSG_NEW_CONN_HELLO) + if err != nil { + nconn.Close() + log.Errorf("write new udp client hello to server %v fail:%v\n", s.CmdTarget, err) + return + } + + rmt, err = net.ResolveUDPAddr("udp", s.ServiceTarget) + if err != nil { + log.Errorf("dail server udp conn %v fail:%v\n", s.ServiceTarget, err) + return + } + cconn, err := net.DialUDP("udp", nil, rmt) + if err != nil { + log.Errorf("dail remote udp conn %v fail:%v\n", s.ServiceTarget, err) + return + } + log.Infof("dail remote udp conn %v ok\n", s.ServiceTarget) + go func() { + for { + data := make([]byte, 8192) + nconn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(s.UdpTimeout))) + n, err := nconn.Read(data) + if err != nil { + starlog.Infoln("read from udp server fail:", err) + return + } + _, err = cconn.Write(data[:n]) + //starlog.Debugln("write to udp client:", p, err, cconn.LocalAddr(), cconn.RemoteAddr()) + if err != nil { + starlog.Infoln("write to udp client fail:", err) + return + } + } + }() + go func() { + for { + data := make([]byte, 8192) + cconn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(s.UdpTimeout))) + n, err := cconn.Read(data) + if err != nil { + starlog.Infoln("read from udp server fail:", err) + return + } + _, err = nconn.Write(data[:n]) + if err != nil { + starlog.Infoln("write to udp client fail:", err) + return + } + } + }() } diff --git a/net/natserver.go b/net/natserver.go index fe0c372..cc4115d 100644 --- a/net/natserver.go +++ b/net/natserver.go @@ -22,6 +22,7 @@ var MSG_CMD_HELLO_REPLY, _ = hex.DecodeString("B6121127AF7ECDA22002200820112014" // MSG_NEW_CONN_HELLO 交链路主动连接头 16byte var MSG_NEW_CONN_HELLO, _ = hex.DecodeString("B6121127AF7ECDFF201820202022B612") +// MSG_HEARTBEAT 心跳报文 16byte var MSG_HEARTBEAT, _ = hex.DecodeString("B612112704011008B612112704011008") type NatServer struct { @@ -29,6 +30,9 @@ type NatServer struct { cmdTCPConn net.Conn listenTcp net.Listener listenUDP *net.UDPConn + udpConnMap sync.Map + udpPairMap sync.Map + udpCmdAddr *net.UDPAddr ListenAddr string lastTCPHeart int64 lastUDPHeart int64 @@ -37,6 +41,7 @@ type NatServer struct { UDPTimeout int64 running int32 tcpConnPool chan net.Conn + udpConnPool chan addionData stopCtx context.Context stopFn context.CancelFunc enableTCP bool @@ -54,10 +59,22 @@ func (n *NatServer) Run() error { if n.Passwd != "" { MSG_CMD_HELLO = sha256.New().Sum(append(MSG_CMD_HELLO, []byte(n.Passwd)...))[:16] } - + var wg sync.WaitGroup + if n.enableUDP { + wg.Add(1) + go func() { + defer wg.Done() + n.runUdpListen() + }() + } if n.enableTCP { - go n.runTcpListen() + wg.Add(1) + go func() { + defer wg.Done() + n.runTcpListen() + }() } + wg.Wait() return nil } @@ -100,6 +117,151 @@ func (n *NatServer) runTcpListen() error { } } +func (n *NatServer) runUdpListen() error { + var err error + atomic.AddInt32(&n.running, 1) + defer atomic.AddInt32(&n.running, -1) + starlog.Infoln("nat server udp listener start run") + if n.UDPTimeout == 0 { + n.UDPTimeout = 120 + } + n.udpConnPool = make(chan addionData, 128) + udpListenAddr, err := net.ResolveUDPAddr("udp", n.ListenAddr) + if err != nil { + starlog.Errorln("nat server udp listener start failed:", err) + return err + } + n.listenUDP, err = net.ListenUDP("udp", udpListenAddr) + if err != nil { + starlog.Errorln("nat server tcp listener start failed:", err) + return err + } + go func() { + for { + select { + case <-n.stopCtx.Done(): + if n.listenUDP != nil { + n.listenUDP.Close() + } + case <-time.After(time.Second * 30): + if time.Now().Unix()-n.lastUDPHeart > n.UDPTimeout { + if n.udpCmdAddr != nil { + n.udpCmdAddr = nil + } + } + if n.udpCmdAddr != nil { + n.listenUDP.WriteToUDP(MSG_HEARTBEAT, n.udpCmdAddr) + } + n.udpConnMap.Range(func(key, value interface{}) bool { + if time.Now().Unix()-value.(addionData).lastHeartbeat > n.UDPTimeout { + if taregt, ok := n.udpPairMap.Load(key); ok { + n.udpConnMap.Delete(taregt) + n.udpPairMap.Delete(taregt) + } + n.udpConnMap.Delete(key) + n.udpPairMap.Delete(key) + } + return true + }) + } + } + }() + for { + data := make([]byte, 8192) + c, udpAddr, err := n.listenUDP.ReadFromUDP(data) + if err != nil { + continue + } + n.handleUdpData(udpAddr, data[:c]) + } +} + +type addionData struct { + lastHeartbeat int64 + Addr *net.UDPAddr + MsgFrom []byte +} + +func (n *NatServer) handleUdpData(addr *net.UDPAddr, data []byte) { + starlog.Infoln("handle udp data from:", addr.String()) + if addr.String() == n.udpCmdAddr.String() && len(data) >= 16 { + if bytes.Equal(data[:16], MSG_HEARTBEAT) { + starlog.Infoln("recv udp cmd heartbeat") + n.lastUDPHeart = time.Now().Unix() + } + return + } + if n.udpCmdAddr == nil { + if len(data) >= 16 && bytes.Equal(data[:16], MSG_CMD_HELLO) { + starlog.Infof("recv udp cmd hello from %v\n", addr.String()) + n.udpCmdAddr = addr + n.lastUDPHeart = time.Now().Unix() + n.listenUDP.WriteToUDP(MSG_CMD_HELLO_REPLY, addr) + return + } + } + if _, ok := n.udpConnMap.Load(addr.IP.String()); ok { + if target, ok := n.udpPairMap.Load(addr.IP.String()); ok { + starlog.Infof("found udp pair data %v <=====> %v\n", addr.String(), target.(*net.UDPAddr).String()) + rmt := target.(*net.UDPAddr) + if _, ok := n.udpConnMap.Load(rmt.IP.String()); !ok { + n.udpConnMap.Delete(addr.IP.String()) + n.udpPairMap.Delete(addr.IP.String()) + n.udpPairMap.Delete(rmt.IP.String()) + starlog.Errorf("udp pair data %v <=====> %v fail,remote not found\n", addr.String(), rmt.String()) + return + } + tmp, _ := n.udpConnMap.Load(addr.IP.String()) + current := tmp.(addionData) + current.lastHeartbeat = time.Now().Unix() + n.udpConnMap.Store(addr.IP.String(), current) + return + } + } + if len(data) >= 16 { + if bytes.Equal(data[:16], MSG_NEW_CONN_HELLO) { + starlog.Infof("recv new udp conn hello from %v\n", addr.String()) + if len(data) < 16 { + data = data[16:] + } else { + data = []byte{} + } + n.udpConnMap.Store(addr.IP.String(), addionData{ + lastHeartbeat: time.Now().Unix(), + Addr: addr, + }) + n.udpConnPool <- addionData{ + lastHeartbeat: time.Now().Unix(), + Addr: addr, + MsgFrom: data, + } + return + } + } + starlog.Infof("wait pair udp conn %v\n", addr.String()) + if n.udpCmdAddr == nil { + starlog.Infof("wait pair udp conn %v fail,cmd addr is nil\n", addr.String()) + return + } else { + n.listenUDP.WriteToUDP(MSG_NEW_CONN_HELLO, n.udpCmdAddr) + } + go func() { + pairAddr := <-n.udpConnPool + n.udpConnMap.Store(addr.String(), addionData{ + lastHeartbeat: time.Now().Unix(), + Addr: addr, + }) + n.udpPairMap.Store(addr.IP.String(), pairAddr.Addr) + n.udpPairMap.Store(pairAddr.Addr.String(), addr.IP) + starlog.Infof("pair udp conn %v <=====> %v\n", addr.String(), pairAddr.Addr.String()) + if len(pairAddr.MsgFrom) > 0 { + n.listenUDP.WriteToUDP(pairAddr.MsgFrom, addr) + } + n.listenUDP.WriteToUDP(data, pairAddr.Addr) + }() + +} + func (n *NatServer) pairNewClientConn(conn net.Conn) { log := starlog.Std.NewFlag() log.Noticef("start pair tcp cmd conn %v\n", conn.RemoteAddr().String()) @@ -110,8 +272,16 @@ func (n *NatServer) pairNewClientConn(conn net.Conn) { return case nconn := <-n.tcpConnPool: log.Infof("pair %v <======> %v ok\n", conn.RemoteAddr().String(), nconn.RemoteAddr().String()) - go io.Copy(nconn, conn) - go io.Copy(conn, nconn) + go func() { + defer nconn.Close() + defer conn.Close() + io.Copy(nconn, conn) + }() + go func() { + defer nconn.Close() + defer conn.Close() + io.Copy(conn, nconn) + }() return } }