package net

import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"b612.me/starlog"
)

// udpBufSize is large enough to hold any UDP payload (the IPv4 maximum is
// 65507 bytes), so forwarded datagrams are never truncated.
const udpBufSize = 64 * 1024

// udpSweepInterval is the longest interval between idle UDP session sweeps.
const udpSweepInterval = time.Minute

// NetForward forwards TCP and/or UDP traffic received on
// LocalAddr:LocalPort to RemoteURI.
type NetForward struct {
	// LocalAddr is the host part of the local listening address.
	LocalAddr string

	// LocalPort is the local listening port.
	LocalPort int

	// RemoteURI is the "host:port" address traffic is forwarded to.
	RemoteURI string

	// EnableTCP and EnableUDP select which protocols Run forwards.
	EnableTCP bool
	EnableUDP bool

	// DialTimeout bounds each connection attempt to RemoteURI. Run sets it
	// to 10 seconds when it is zero.
	DialTimeout time.Duration

	// UDPTimeout is how long a UDP session may stay idle before it is
	// dropped. Run sets it to 60 seconds when it is zero.
	UDPTimeout time.Duration

	mu      sync.Mutex // guards stopCtx and stopFn against a concurrent Close
	stopCtx context.Context
	stopFn  context.CancelFunc
	running int32 // 1 while Run is active; accessed atomically
}

// Close stops a running forwarder. It cancels the stop context, which closes
// the listeners, tears down active TCP connections and UDP sessions, and
// makes Run return. Calling Close before Run, or more than once, is a no-op.
func (n *NetForward) Close() {
	n.mu.Lock()
	stop := n.stopFn
	n.mu.Unlock()
	if stop != nil {
		stop()
	}
}

// Run starts the enabled listeners and blocks until Close is called or no
// enabled listener is left running. It returns an error if Run is already
// active or if a listener could not be opened (the TCP error takes
// precedence when both fail); otherwise it returns nil. Run may be called
// again after it returns.
func (n *NetForward) Run() error {
	if !atomic.CompareAndSwapInt32(&n.running, 0, 1) {
		return errors.New("already running")
	}
	// Deferred first so it runs last: allow a later Run once this one is done.
	defer atomic.StoreInt32(&n.running, 0)

	ctx, cancel := context.WithCancel(context.Background())
	// Release the context (and every goroutine waiting on it) on all exit
	// paths, including when the listeners fail to open.
	defer cancel()
	n.mu.Lock()
	n.stopCtx, n.stopFn = ctx, cancel
	n.mu.Unlock()

	if n.DialTimeout == 0 {
		n.DialTimeout = time.Second * 10
	}
	if n.UDPTimeout == 0 {
		// A zero timeout would expire every session at the next sweep.
		n.UDPTimeout = time.Second * 60
	}

	var (
		wg             sync.WaitGroup
		tcpErr, udpErr error
	)
	if n.EnableTCP {
		wg.Add(1)
		go func() {
			defer wg.Done()
			tcpErr = n.runTCP()
		}()
	}
	if n.EnableUDP {
		wg.Add(1)
		go func() {
			defer wg.Done()
			udpErr = n.runUDP()
		}()
	}
	wg.Wait()
	if tcpErr != nil {
		return tcpErr
	}
	return udpErr
}

// runTCP accepts TCP connections on the local address and relays each one to
// RemoteURI until the stop context is cancelled. It returns an error only if
// the listener cannot be opened.
func (n *NetForward) runTCP() error {
	addr := fmt.Sprintf("%s:%d", n.LocalAddr, n.LocalPort)
	listen, err := net.Listen("tcp", addr)
	if err != nil {
		starlog.Errorln("Listening On Tcp Failed:", err)
		return err
	}
	go func() {
		<-n.stopCtx.Done()
		listen.Close()
	}()
	starlog.Infof("Listening TCP on %v\n", addr)
	for {
		conn, err := listen.Accept()
		if err != nil {
			// The listener is only closed by the watcher above, so an error
			// after cancellation is a normal shutdown.
			if n.stopCtx.Err() != nil {
				return nil
			}
			// Back off briefly on other errors (e.g. too many open files)
			// instead of spinning on Accept.
			starlog.Errorln("Accept TCP Conn Failed:", err)
			select {
			case <-n.stopCtx.Done():
				return nil
			case <-time.After(100 * time.Millisecond):
			}
			continue
		}
		go n.handleTCP(conn)
	}
}

// handleTCP dials RemoteURI for an accepted client connection and relays data
// in both directions until the transfer ends or the forwarder is stopped.
func (n *NetForward) handleTCP(conn net.Conn) {
	log := starlog.Std.NewFlag()
	client := conn.RemoteAddr().String()
	log.Infof("Accept New TCP Conn from %v\n", client)

	dialer := net.Dialer{Timeout: n.DialTimeout}
	rmt, err := dialer.DialContext(n.stopCtx, "tcp", n.RemoteURI)
	if err != nil {
		log.Errorf("Dial Remote %s Failed:%v\n", n.RemoteURI, err)
		conn.Close()
		return
	}
	log.Infof("Connect %s <==> %s\n", client, n.RemoteURI)

	// Tear the pair down if the forwarder is stopped mid-connection.
	done := make(chan struct{})
	go func() {
		select {
		case <-n.stopCtx.Done():
			conn.Close()
			rmt.Close()
		case <-done:
		}
	}()
	Copy(rmt, conn)
	close(done)
	log.Noticef("Connection Closed %s <==> %s", client, n.RemoteURI)
}

// UDPConn is a single forwarded UDP session. The embedded Conn is the socket
// dialed to the remote target; datagrams read from it are relayed to
// remoteAddr through the shared listening socket listen.
type UDPConn struct {
	// lastbeat is the Unix time (seconds) of the last successful read or
	// write on the session. It is accessed atomically and kept as the first
	// field so it is 64-bit aligned on 32-bit platforms.
	lastbeat int64

	net.Conn
	listen     *net.UDPConn
	remoteAddr *net.UDPAddr
}

// touch records activity on the session.
func (u *UDPConn) touch() {
	atomic.StoreInt64(&u.lastbeat, time.Now().Unix())
}

// Write sends p to the remote target and records the activity on success.
func (u *UDPConn) Write(p []byte) (n int, err error) {
	n, err = u.Conn.Write(p)
	if err == nil {
		u.touch()
	}
	return n, err
}

// Read receives a datagram from the remote target and records the activity
// on success.
func (u *UDPConn) Read(p []byte) (n int, err error) {
	n, err = u.Conn.Read(p)
	if err == nil {
		u.touch()
	}
	return n, err
}

// Work relays datagrams received from the remote target back to the client
// until a read or write fails (including because the session was closed).
// It then closes the session and zeroes lastbeat so an idle sweep treats it
// as expired.
func (u *UDPConn) Work() {
	defer func() {
		u.Close()
		atomic.StoreInt64(&u.lastbeat, 0)
	}()
	buf := make([]byte, udpBufSize)
	for {
		count, err := u.Read(buf)
		if err != nil {
			return
		}
		// The listening socket is not connected, so every reply must be
		// addressed to the client explicitly.
		if _, err := u.listen.WriteToUDP(buf[:count], u.remoteAddr); err != nil {
			return
		}
	}
}

// udpSessionTable maps client address strings to their forwarding sessions.
type udpSessionTable struct {
	mu sync.Mutex
	m  map[string]*UDPConn
}

// remove deletes key only if it still maps to s, reporting whether it did.
func (t *udpSessionTable) remove(key string, s *UDPConn) bool {
	t.mu.Lock()
	defer t.mu.Unlock()
	if t.m[key] != s {
		return false
	}
	delete(t.m, key)
	return true
}

// closeAll closes and forgets every session.
func (t *udpSessionTable) closeAll() {
	t.mu.Lock()
	defer t.mu.Unlock()
	for key, s := range t.m {
		s.Close()
		delete(t.m, key)
	}
}

// runUDP receives datagrams on the local address and relays them to
// RemoteURI through one dialed socket per client address, sending replies
// back to that client. Sessions idle for longer than UDPTimeout are dropped.
// It returns an error only if the listening socket cannot be opened.
func (n *NetForward) runUDP() error {
	addr := fmt.Sprintf("%s:%d", n.LocalAddr, n.LocalPort)
	udpAddr, err := net.ResolveUDPAddr("udp", addr)
	if err != nil {
		starlog.Errorln("Listening On Udp Failed:", err)
		return err
	}
	listen, err := net.ListenUDP("udp", udpAddr)
	if err != nil {
		starlog.Errorln("Listening On Udp Failed:", err)
		return err
	}
	starlog.Infof("Listening UDP on %v\n", addr)
	go func() {
		<-n.stopCtx.Done()
		listen.Close()
	}()

	sessions := &udpSessionTable{m: make(map[string]*UDPConn)}
	// The read loop below is the only place sessions are created, so closing
	// whatever remains after it exits leaves nothing behind.
	defer sessions.closeAll()
	go n.sweepUDP(sessions)

	buf := make([]byte, udpBufSize)
	for {
		count, client, err := listen.ReadFromUDP(buf)
		if err != nil {
			if n.stopCtx.Err() != nil {
				return nil
			}
			continue
		}
		if client.String() == n.RemoteURI {
			continue
		}
		// Forwarding is synchronous, so buf can be reused for the next
		// datagram: Write has already copied the payload to the socket.
		n.forwardUDP(listen, sessions, client, buf[:count])
	}
}

// sweepUDP periodically closes sessions that have been idle for longer than
// UDPTimeout, until the stop context is cancelled.
func (n *NetForward) sweepUDP(sessions *udpSessionTable) {
	interval := udpSweepInterval
	if n.UDPTimeout > 0 && n.UDPTimeout < interval {
		interval = n.UDPTimeout
	}
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	timeout := int64(n.UDPTimeout.Seconds())
	for {
		select {
		case <-n.stopCtx.Done():
			return
		case <-ticker.C:
			now := time.Now().Unix()
			sessions.mu.Lock()
			for key, s := range sessions.m {
				if now > timeout+atomic.LoadInt64(&s.lastbeat) {
					delete(sessions.m, key)
					s.Close()
					starlog.Noticef("Connection Closed %s <==> %s", s.remoteAddr.String(), n.RemoteURI)
				}
			}
			sessions.mu.Unlock()
		}
	}
}

// forwardUDP sends one datagram from client to RemoteURI. On first contact
// from client it creates the session and starts its reply relay.
func (n *NetForward) forwardUDP(listen *net.UDPConn, sessions *udpSessionTable, client *net.UDPAddr, data []byte) {
	log := starlog.Std.NewFlag()
	key := client.String()

	sessions.mu.Lock()
	sess, ok := sessions.m[key]
	if !ok {
		log.Infof("Accept New UDP Conn from %v\n", key)
		dialer := net.Dialer{Timeout: n.DialTimeout}
		conn, err := dialer.DialContext(n.stopCtx, "udp", n.RemoteURI)
		if err != nil {
			sessions.mu.Unlock()
			log.Errorf("Dial Remote %s Failed:%v\n", n.RemoteURI, err)
			return
		}
		created := &UDPConn{
			lastbeat:   time.Now().Unix(),
			Conn:       conn,
			listen:     listen,
			remoteAddr: client,
		}
		sessions.m[key] = created
		go func() {
			created.Work()
			// Only report the close if no sweep or write error got here first.
			if sessions.remove(key, created) {
				log.Noticef("Connection Closed %s <==> %s", key, n.RemoteURI)
			}
		}()
		log.Infof("Connect %s <==> %s\n", key, n.RemoteURI)
		sess = created
	}
	sessions.mu.Unlock()

	if _, err := sess.Write(data); err != nil {
		sess.Close()
		if sessions.remove(key, sess) {
			log.Noticef("Connection Closed %s <==> %s", key, n.RemoteURI)
		}
	}
}

// halfCloser is implemented by connections that can shut down their write
// side independently, such as *net.TCPConn.
type halfCloser interface {
	CloseWrite() error
}

// Copy relays data between dst and src in both directions and closes both
// connections once both directions have finished.
//
// When one direction ends with a clean EOF and the receiving connection
// supports CloseWrite, only its write side is shut down. The peer then sees
// EOF while the other direction keeps flowing. In two cases both connections
// are closed instead, so the opposite direction cannot block forever: when
// the copy ends with an error, or when half-close is unavailable.
func Copy(dst, src net.Conn) {
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		relay(dst, src)
	}()
	go func() {
		defer wg.Done()
		relay(src, dst)
	}()
	wg.Wait()
	dst.Close()
	src.Close()
}

// relay copies from `from` into `to` and then signals the end of that
// direction, as described on Copy.
func relay(to, from net.Conn) {
	_, err := io.Copy(to, from)
	if hc, ok := to.(halfCloser); ok && err == nil {
		if hc.CloseWrite() == nil {
			return
		}
	}
	to.Close()
	from.Close()
}