package netforward import ( "b612.me/stario" "b612.me/starlog" "b612.me/starmap" "context" "errors" "fmt" "io" "net" "strconv" "strings" "sync" "sync/atomic" "syscall" "time" ) type NetForward struct { LocalAddr string LocalPort int RemoteURI string EnableTCP bool EnableUDP bool DelayMilSec int DelayToward int StdinMode bool IgnoreEof bool DialTimeout time.Duration UDPTimeout time.Duration stopCtx context.Context stopFn context.CancelFunc running int32 UdpHooks map[string]*starmap.StarStack KeepAlivePeriod int KeepAliveIdel int KeepAliveCount int UserTimeout int UsingKeepAlive bool Verbose bool udpListener *net.UDPConn } func (n *NetForward) UdpListener() *net.UDPConn { return n.udpListener } func (n *NetForward) Close() { n.stopFn() } func (n *NetForward) Status() int32 { return atomic.LoadInt32(&n.running) } func (n *NetForward) Run() error { if n.running > 0 { starlog.Errorln("already running") return errors.New("already running") } n.stopCtx, n.stopFn = context.WithCancel(context.Background()) if n.DialTimeout == 0 { n.DialTimeout = time.Second * 5 } if n.StdinMode { go func() { for { cmd := strings.TrimSpace(stario.MessageBox("", "").MustString()) for strings.Contains(cmd, " ") { cmd = strings.Replace(cmd, " ", " ", -1) } starlog.Debugf("Recv Command %s\n", cmd) cmds := strings.Split(cmd, " ") if len(cmds) < 3 { starlog.Errorln("Invalid Command", cmd) continue } switch cmds[0] + cmds[1] { case "setremote": n.RemoteURI = cmds[2] starlog.Noticef("Remote URI Set to %s\n", n.RemoteURI) case "setdelaytoward": tmp, err := strconv.Atoi(cmds[2]) if err != nil { starlog.Errorln("Invalid Delay Toward Value", cmds[2]) continue } n.DelayToward = tmp starlog.Noticef("Delay Toward Set to %d\n", n.DelayToward) case "setdelay": tmp, err := strconv.Atoi(cmds[2]) if err != nil { starlog.Errorln("Invalid Delay Value", cmds[2]) continue } n.DelayMilSec = tmp starlog.Noticef("Delay Set to %d\n", n.DelayMilSec) case "setdialtimeout": tmp, err := strconv.Atoi(cmds[2]) if err != nil { starlog.Errorln("Invalid Dial Timeout Value", cmds[2]) continue } n.DialTimeout = time.Millisecond * time.Duration(tmp) starlog.Noticef("Dial Timeout Set to %d\n", n.DialTimeout) case "setudptimeout": tmp, err := strconv.Atoi(cmds[2]) if err != nil { starlog.Errorln("Invalid UDP Timeout Value", cmds[2]) continue } n.UDPTimeout = time.Millisecond * time.Duration(tmp) starlog.Noticef("UDP Timeout Set to %d\n", n.UDPTimeout) case "setstdin": if cmds[2] == "off" { n.StdinMode = false starlog.Noticef("Stdin Mode Off\n") return } } } }() } if n.EnableTCP { go n.runTCP() } if n.EnableUDP { go n.runUDP() } return nil } func (n *NetForward) runTCP() error { atomic.AddInt32(&n.running, 1) defer atomic.AddInt32(&n.running, -1) cfg := net.ListenConfig{ Control: func(network, address string, c syscall.RawConn) error { return c.Control(SetReUseAddr) }, } listen, err := cfg.Listen(context.Background(), "tcp", fmt.Sprintf("%s:%d", n.LocalAddr, n.LocalPort)) if err != nil { starlog.Errorln("Listening On Tcp Failed:", err) return err } go func() { <-n.stopCtx.Done() listen.Close() }() starlog.Infof("Listening TCP on %v\n", fmt.Sprintf("%s:%d", n.LocalAddr, n.LocalPort)) for { select { case <-n.stopCtx.Done(): return nil default: } conn, err := listen.Accept() if err != nil { continue } log := starlog.Std.NewFlag() log.Infof("Accept New TCP Conn from %v\n", conn.RemoteAddr().String()) if n.DelayMilSec > 0 && (n.DelayToward == 0 || n.DelayToward == 1) { log.Infof("Delay %d ms\n", n.DelayMilSec) time.Sleep(time.Millisecond * time.Duration(n.DelayMilSec)) } err = SetTcpInfo(conn.(*net.TCPConn), n.UsingKeepAlive, n.KeepAliveIdel, n.KeepAlivePeriod, n.KeepAliveCount, n.UserTimeout) if err != nil { log.Errorf("SetTcpInfo error for %s: %v\n", conn.RemoteAddr().String(), err) conn.Close() continue } go func(conn net.Conn) { rmt, err := net.DialTimeout("tcp", n.RemoteURI, n.DialTimeout) if err != nil { log.Errorf("TCP:Dial Remote %s Failed:%v\n", n.RemoteURI, err) conn.Close() return } err = SetTcpInfo(rmt.(*net.TCPConn), n.UsingKeepAlive, n.KeepAliveIdel, n.KeepAlivePeriod, n.KeepAliveCount, n.UserTimeout) if err != nil { log.Errorf("SetTcpInfo error for %s: %v\n", conn.RemoteAddr().String(), err) rmt.Close() return } log.Infof("TCP Connect %s <==> %s\n", conn.RemoteAddr().String(), rmt.RemoteAddr().String()) n.copy(rmt, conn) log.Noticef("TCP Connection Closed %s <==> %s\n", conn.RemoteAddr().String(), n.RemoteURI) conn.Close() rmt.Close() }(conn) } } type UDPConn struct { net.Conn listen *net.UDPConn remoteAddr *net.UDPAddr lastbeat int64 } func (u *UDPConn) Write(p []byte) (n int, err error) { u.lastbeat = time.Now().Unix() return u.Conn.Write(p) } func (u *UDPConn) Read(p []byte) (n int, err error) { u.lastbeat = time.Now().Unix() return u.Conn.Read(p) } func (u *UDPConn) Work(delay int, verbose bool) { buf := make([]byte, 8192) for { if delay > 0 { time.Sleep(time.Millisecond * time.Duration(delay)) } count, err := u.Read(buf) if err != nil { u.Close() u.lastbeat = 0 return } if verbose { fmt.Printf("U %v Recv Data %s ==> %s %X\n", time.Now().Format("2006-01-02 15:04:05"), u.Conn.RemoteAddr().String(), u.remoteAddr.String(), buf[0:count]) } _, err = u.listen.WriteTo(buf[0:count], u.remoteAddr) if err != nil { u.lastbeat = 0 return } } } func (n *NetForward) runUDP() error { var mu sync.RWMutex atomic.AddInt32(&n.running, 1) defer atomic.AddInt32(&n.running, -1) udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%v", n.LocalAddr, n.LocalPort)) if err != nil { return err } listen, err := net.ListenUDP("udp", udpAddr) if err != nil { return err } n.udpListener = listen starlog.Infof("Listening UDP on %v\n", fmt.Sprintf("%s:%d", n.LocalAddr, n.LocalPort)) go func() { <-n.stopCtx.Done() listen.Close() }() udpMap := make(map[string]*UDPConn) go func() { for { select { case <-n.stopCtx.Done(): return case <-time.After(time.Second * 60): mu.Lock() for k, v := range udpMap { if time.Now().Unix() > int64(n.UDPTimeout.Seconds())+v.lastbeat { delete(udpMap, k) starlog.Noticef("UDP Connection Closed %s <==> %s\n", v.remoteAddr.String(), n.RemoteURI) } } mu.Unlock() } } }() buf := make([]byte, 8192) for { select { case <-n.stopCtx.Done(): return nil default: } count, rmt, err := listen.ReadFromUDP(buf) if err != nil || rmt.String() == n.RemoteURI { continue } { //hooks if n.UdpHooks != nil { if m, ok := n.UdpHooks[rmt.String()]; ok { if m.Free() > 0 { if n.Verbose { starlog.Noticef("Hooked UDP Data %s ==> %s %X\n", rmt.String(), n.RemoteURI, buf[0:count]) } else { starlog.Noticef("Hooked UDP Data %s ==> %s\n", rmt.String(), n.RemoteURI) } m.Push(buf[0:count]) continue } } } } go func(data []byte, rmt *net.UDPAddr) { log := starlog.Std.NewFlag() mu.Lock() addr, ok := udpMap[rmt.String()] if !ok { log.Infof("Accept New UDP Conn from %v\n", rmt.String()) conn, err := net.Dial("udp", n.RemoteURI) if err != nil { log.Errorf("UDP:Dial Remote %s Failed:%v\n", n.RemoteURI, err) mu.Unlock() return } addr = &UDPConn{ Conn: conn, remoteAddr: rmt, listen: listen, lastbeat: time.Now().Unix(), } udpMap[rmt.String()] = addr go addr.Work(n.DelayMilSec, n.Verbose) log.Infof("UDP Connect %s <==> %s\n", rmt.String(), n.RemoteURI) } mu.Unlock() if n.DelayMilSec > 0 || (n.DelayToward == 0 || n.DelayToward == 1) { time.Sleep(time.Millisecond * time.Duration(n.DelayMilSec)) } if n.Verbose { fmt.Printf("T %v Recv Data %s ==> %s %X\n", time.Now().Format("2006-01-02 15:04:05"), rmt.String(), n.RemoteURI, data) } _, err := addr.Write(data) if err != nil { mu.Lock() addr.Close() delete(udpMap, addr.remoteAddr.String()) mu.Unlock() log.Noticef("UDP Connection Closed %s <==> %s\n", rmt.String(), n.RemoteURI) } }(buf[0:count], rmt) } } func (n *NetForward) showVerbose(toward, src, dst string, data []byte) { if n.Verbose { fmt.Printf("%s %v Recv Data %s ==> %s %X\n", toward, time.Now().Format("2006-01-02 15:04:05"), src, dst, data) } } func (n *NetForward) copy(dst, src net.Conn) { var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() bufsize := make([]byte, 32*1024) for { count, err := src.Read(bufsize) if err != nil { if n.IgnoreEof && err == io.EOF { continue } dst.Close() src.Close() return } n.showVerbose("T", src.RemoteAddr().String(), dst.RemoteAddr().String(), bufsize[:count]) _, err = dst.Write(bufsize[:count]) if err != nil { src.Close() dst.Close() return } if n.DelayMilSec > 0 && (n.DelayToward == 0 || n.DelayToward == 1) { time.Sleep(time.Millisecond * time.Duration(n.DelayMilSec)) } } }() go func() { defer wg.Done() bufsize := make([]byte, 32*1024) for { count, err := dst.Read(bufsize) if err != nil { if n.IgnoreEof && err == io.EOF { continue } src.Close() dst.Close() return } n.showVerbose("U", dst.RemoteAddr().String(), src.RemoteAddr().String(), bufsize[:count]) _, err = src.Write(bufsize[:count]) if err != nil { src.Close() dst.Close() return } if n.DelayMilSec > 0 && (n.DelayToward == 0 || n.DelayToward == 2) { time.Sleep(time.Millisecond * time.Duration(n.DelayMilSec)) } } }() wg.Wait() }