package net

import (
	"bytes"
	"context"
	"crypto/sha256"
	"io"
	"net"
	"sync"
	"time"

	"b612.me/starlog"
)

// NatClient keeps a TCP command connection to a remote server (CmdTarget).
// Whenever the server sends MSG_NEW_CONN_HELLO on that connection, the
// client opens a new data connection to the server, dials the local
// ServiceTarget, and relays bytes between the two.
type NatClient struct {
	// mu guards cmdTCPConn and tcpAlived, which are shared between the
	// reconnect loop (runTcp) and the command reader (handleTcpCmdConn).
	mu         sync.RWMutex
	cmdTCPConn net.Conn
	// cmdUDPConn is not used in this file.
	cmdUDPConn *net.UDPAddr
	// ServiceTarget is the address dialed for every relayed connection.
	ServiceTarget string
	// CmdTarget is the server address used for the command connection and
	// for every data connection.
	CmdTarget string
	tcpAlived bool
	// DialTimeout is the dial timeout in milliseconds; Run replaces 0 with 10000.
	DialTimeout int
	enableTCP   bool
	// enableUDP is not used in this file.
	enableUDP bool
	// Passwd, when non-empty, is fed into the handshake derivation in Run.
	Passwd string
	// stopCtx is created by Run; cancelling it (via stopFn) stops runTcp.
	stopCtx context.Context
	stopFn  context.CancelFunc
}

// tcpCmdConn returns the current command connection (nil if none).
func (s *NatClient) tcpCmdConn() net.Conn {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.cmdTCPConn
}

// tcpCmdConnAlived reports whether the command connection is marked usable.
func (s *NatClient) tcpCmdConnAlived() bool {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.tcpAlived
}

// setTcpCmdConnAlived marks the command connection as usable or dead.
func (s *NatClient) setTcpCmdConnAlived(v bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.tcpAlived = v
}

// Run creates the stop context, applies the default DialTimeout, derives the
// handshake bytes from Passwd, and, if the TCP module is enabled, blocks in
// runTcp until the stop context is cancelled.
func (s *NatClient) Run() {
	s.stopCtx, s.stopFn = context.WithCancel(context.Background())
	if s.DialTimeout == 0 {
		s.DialTimeout = 10000
	}
	if s.Passwd != "" {
		// NOTE(review): hash.Hash.Sum(b) appends the digest of the hash's
		// current (here: empty) state to b; it does not hash b. Assuming
		// MSG_CMD_HELLO is at least 16 bytes, the [:16] slice therefore yields
		// MSG_CMD_HELLO's first 16 bytes unchanged and Passwd has no effect.
		// This also mutates a package-level var, which races if several
		// clients call Run concurrently. Left as-is because the server must
		// derive the same bytes; fix both sides together (e.g. sha256.Sum256).
		MSG_CMD_HELLO = sha256.New().Sum(append(MSG_CMD_HELLO, []byte(s.Passwd)...))[:16]
	}
	if s.enableTCP {
		s.runTcp()
	}
}

// runTcp maintains the TCP command connection until stopCtx is cancelled.
// Every 1.5s it checks the connection; if there is none or it has been marked
// dead, it dials CmdTarget, sends MSG_CMD_HELLO, and starts handleTcpCmdConn
// on the new connection. It always returns nil.
func (s *NatClient) runTcp() error {
	starlog.Noticeln("nat client tcp module start run")
	// storeConn publishes the command connection under mu so readers that go
	// through tcpCmdConn() never race with this loop.
	storeConn := func(c net.Conn, alive bool) {
		s.mu.Lock()
		s.cmdTCPConn = c
		s.tcpAlived = alive
		s.mu.Unlock()
	}
	for {
		select {
		case <-s.stopCtx.Done():
			// Return whether or not a connection is currently held; checking
			// for nil before returning would make the loop spin forever when
			// stopped while disconnected.
			if conn := s.tcpCmdConn(); conn != nil {
				s.setTcpCmdConnAlived(false)
				conn.Close()
			}
			return nil
		case <-time.After(time.Millisecond * 1500):
		}
		if s.tcpCmdConn() != nil && s.tcpCmdConnAlived() {
			continue
		}
		conn, err := net.DialTimeout("tcp", s.CmdTarget, time.Millisecond*time.Duration(s.DialTimeout))
		if err != nil {
			starlog.Errorf("dail remote tcp cmd server %v fail:%v;will retry\n", s.CmdTarget, err)
			storeConn(nil, s.tcpCmdConnAlived())
			// Back off before retrying, but wake immediately on stop so the
			// select at the top of the loop can return.
			select {
			case <-s.stopCtx.Done():
			case <-time.After(time.Second * 2):
			}
			continue
		}
		starlog.Infoln("dail remote tcp cmd server ok,remote:", s.CmdTarget)
		if _, err := conn.Write(MSG_CMD_HELLO); err != nil {
			// Without the hello the server will not treat this as a command
			// connection, so drop it and redial on the next tick.
			starlog.Errorf("write hello to tcp cmd server %v fail:%v;will retry\n", s.CmdTarget, err)
			conn.Close()
			storeConn(nil, false)
			continue
		}
		storeConn(conn, true)
		go s.handleTcpCmdConn(conn)
	}
}

// handleTcpCmdConn reads 16-byte command headers from conn. It ignores
// MSG_CMD_HELLO_REPLY, starts newRemoteTcpConn on MSG_NEW_CONN_HELLO, and
// echoes MSG_HEARTBEAT; any other header is ignored. On a read or write
// failure it closes conn and marks the command connection dead so runTcp
// reconnects.
func (s *NatClient) handleTcpCmdConn(conn net.Conn) {
	// ReadFull overwrites the whole buffer each time, so one allocation suffices.
	header := make([]byte, 16)
	for {
		if _, err := io.ReadFull(conn, header); err != nil {
			starlog.Infoln("tcp cmd server read fail:", err)
			conn.Close()
			s.setTcpCmdConnAlived(false)
			return
		}
		switch {
		case bytes.Equal(header, MSG_CMD_HELLO_REPLY):
			// Handshake acknowledgement; nothing to do.
		case bytes.Equal(header, MSG_NEW_CONN_HELLO):
			go s.newRemoteTcpConn()
		case bytes.Equal(header, MSG_HEARTBEAT):
			if _, err := conn.Write(MSG_HEARTBEAT); err != nil {
				conn.Close()
				s.setTcpCmdConnAlived(false)
				return
			}
		}
	}
}

// newRemoteTcpConn opens a data connection to CmdTarget, announces it with
// MSG_NEW_CONN_HELLO, dials ServiceTarget, and relays bytes in both
// directions. It returns once the relay is started; both connections are
// closed after both directions have finished.
func (s *NatClient) newRemoteTcpConn() {
	log := starlog.Std.NewFlag()
	starlog.Infoln("recv request,create new tcp conn")
	timeout := time.Millisecond * time.Duration(s.DialTimeout)
	nconn, err := net.DialTimeout("tcp", s.CmdTarget, timeout)
	if err != nil {
		log.Errorf("dail server tcp conn %v fail:%v\n", s.CmdTarget, err)
		return
	}
	if _, err = nconn.Write(MSG_NEW_CONN_HELLO); err != nil {
		nconn.Close()
		log.Errorf("write new client hello to server %v fail:%v\n", s.CmdTarget, err)
		return
	}
	cconn, err := net.DialTimeout("tcp", s.ServiceTarget, timeout)
	if err != nil {
		// The server-side connection is already open; release it.
		nconn.Close()
		log.Errorf("dail remote tcp conn %v fail:%v\n", s.ServiceTarget, err)
		return
	}

	var wg sync.WaitGroup
	wg.Add(2)
	relay := func(dst, src net.Conn) {
		defer wg.Done()
		// Copy errors are expected at teardown (peer reset/close) and carry
		// no information we act on.
		_, _ = io.Copy(dst, src)
		// Propagate EOF to dst's peer while the opposite direction may still
		// be draining; fall back to a full close if half-close is unsupported.
		if cw, ok := dst.(interface{ CloseWrite() error }); ok {
			_ = cw.CloseWrite()
		} else {
			dst.Close()
		}
	}
	go relay(cconn, nconn)
	go relay(nconn, cconn)
	go func() {
		wg.Wait()
		nconn.Close()
		cconn.Close()
	}()
}